From 880dd8bf6aedcb21642fe6703b76d2f6a3c3aaaf Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Mon, 16 Mar 2020 14:20:54 +0530 Subject: [PATCH] jasper asr first commit --- .gitignore | 110 +++++++++++++++++++++++++++++++++++++++++++++ README.md | 36 +++++++++++++++ jasper/__init__.py | 1 + jasper/__main__.py | 32 +++++++++++++ jasper/asr.py | 100 +++++++++++++++++++++++++++++++++++++++++ setup.py | 17 +++++++ 6 files changed, 296 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 jasper/__init__.py create mode 100644 jasper/__main__.py create mode 100644 jasper/asr.py create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..aab7ea0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,110 @@ + +# Created by https://www.gitignore.io/api/python +# Edit at https://www.gitignore.io/?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# End of https://www.gitignore.io/api/python diff --git a/README.md b/README.md new file mode 100644 index 0000000..a573150 --- /dev/null +++ b/README.md @@ -0,0 +1,36 @@ +# Jasper ASR + +[![image](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black) + +> Generates text from speech audio +--- + +# Table of Contents + +* [Features](#features) +* [Installation](#installation) +* [Usage](#usage) + +# Features + +* ASR using Jasper (from [NemoToolkit](https://github.com/NVIDIA/NeMo)) + + +# Installation +To install the package and its dependencies, run: +```bash +python setup.py install +``` +or with pip +```bash +pip install . +``` + +The installation should work on Python 3.6 or newer. 
Untested on Python 2.7
+
+# Usage
+```python
+from jasper.asr import JasperASR
+asr_model = JasperASR("/path/to/model_config_yaml","/path/to/encoder_checkpoint","/path/to/decoder_checkpoint") # Loads the models
+TEXT = asr_model.transcribe(wav_data) # Returns the text spoken in the wav
+```
diff --git a/jasper/__init__.py b/jasper/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/jasper/__init__.py
@@ -0,0 +1 @@
+
diff --git a/jasper/__main__.py b/jasper/__main__.py
new file mode 100644
index 0000000..04431b3
--- /dev/null
+++ b/jasper/__main__.py
@@ -0,0 +1,42 @@
+"""Command-line entry point: transcribe one audio file with a Jasper model."""
+import os
+import argparse
+from pathlib import Path
+from .asr import JasperASR
+
+# Model file locations; each one can be overridden through the environment.
+MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml")
+CHECKPOINT_ENCODER = os.environ.get(
+    "JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt"
+)
+CHECKPOINT_DECODER = os.environ.get(
+    "JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt"
+)
+
+
+def arg_parser():
+    """Build the argument parser of the transcription command."""
+    prog = Path(__file__).stem
+    parser = argparse.ArgumentParser(
+        prog=prog, description="generates transcription of the audio_file"
+    )
+    parser.add_argument(
+        "--audio_file",
+        type=Path,
+        # Without a file there is nothing to transcribe; fail with a usage error.
+        required=True,
+        help="audio file(16khz 1channel int16 wav) to transcribe",
+    )
+    return parser
+
+
+def main():
+    """Parse the command line, load the models and transcribe the audio file."""
+    parser = arg_parser()
+    args = parser.parse_args()
+    jasper_asr = JasperASR(MODEL_YAML, CHECKPOINT_ENCODER, CHECKPOINT_DECODER)
+    jasper_asr.transcribe_file(args.audio_file)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/jasper/asr.py b/jasper/asr.py
new file mode 100644
index 0000000..2fcc898
--- /dev/null
+++ b/jasper/asr.py
@@ -0,0 +1,161 @@
+"""Speech-to-text inference with the Jasper model from NVIDIA NeMo."""
+import os
+import tempfile
+from pathlib import Path
+from ruamel.yaml import YAML
+import json
+import nemo
+import nemo.collections.asr as nemo_asr
+
+logging = nemo.logging
+
+# Directory that receives the temporary audio and manifest files.
+WORK_DIR = "/tmp"
+
+
+class JasperASR(object):
+    """Transcribe speech audio with a Jasper encoder and a CTC decoder.
+
+    The NeMo modules are created and both checkpoints are restored once, in
+    the constructor; every ``transcribe`` call reuses them.
+    """
+
+    def __init__(self, model_yaml, encoder_checkpoint, decoder_checkpoint):
+        """Read the model definition and restore the two checkpoints.
+
+        Args:
+            model_yaml: path of the model definition YAML; the ``labels``,
+                ``JasperEncoder`` and ``AudioToMelSpectrogramPreprocessor``
+                entries are read from it.
+            encoder_checkpoint: checkpoint restored into the encoder.
+            decoder_checkpoint: checkpoint restored into the decoder.
+        """
+        super(JasperASR, self).__init__()
+        # Read model YAML
+        yaml = YAML(typ="safe")
+        with open(model_yaml) as f:
+            jasper_model_definition = yaml.load(f)
+        self.neural_factory = nemo.core.NeuralModuleFactory(
+            placement=nemo.core.DeviceType.GPU, backend=nemo.core.Backend.PyTorch
+        )
+        self.labels = jasper_model_definition["labels"]
+        # NOTE(review): built with the library defaults rather than the YAML's
+        # AudioToMelSpectrogramPreprocessor section -- confirm this is intended.
+        self.data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor()
+        self.jasper_encoder = nemo_asr.JasperEncoder(
+            jasper=jasper_model_definition["JasperEncoder"]["jasper"],
+            activation=jasper_model_definition["JasperEncoder"]["activation"],
+            feat_in=jasper_model_definition["AudioToMelSpectrogramPreprocessor"][
+                "features"
+            ],
+        )
+        self.jasper_encoder.restore_from(encoder_checkpoint, local_rank=0)
+        # NOTE(review): 1024 presumably matches the last encoder block's filter
+        # count in the YAML -- confirm before using another model definition.
+        self.jasper_decoder = nemo_asr.JasperDecoderForCTC(
+            feat_in=1024, num_classes=len(self.labels)
+        )
+        self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0)
+        self.greedy_decoder = nemo_asr.GreedyCTCDecoder()
+
+    def transcribe(self, audio_data, greedy=True):
+        """Return the text spoken in ``audio_data``.
+
+        The audio is written to a temporary file, referenced from a
+        single-entry manifest and run through the inference graph; both
+        temporary files are removed before returning.
+
+        Args:
+            audio_data: bytes of the audio to transcribe.
+            greedy: post-process the greedy CTC decoder output; this is the
+                only decoding path that is currently wired in.
+
+        Returns:
+            The decoded text, utterances joined with ``". "``.
+        """
+        audio_file_path = manifest_file_path = None
+        try:
+            with tempfile.NamedTemporaryFile(
+                dir=WORK_DIR, prefix="jasper_audio.", delete=False
+            ) as audio_file:
+                audio_file_path = audio_file.name
+                audio_file.write(audio_data)
+            # NOTE(review): "duration" and "text" look like placeholders -- confirm.
+            manifest = {
+                "audio_filepath": audio_file_path,
+                "duration": 60,
+                "text": "todo",
+            }
+            with tempfile.NamedTemporaryFile(
+                dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w"
+            ) as manifest_file:
+                manifest_file_path = manifest_file.name
+                manifest_file.write(json.dumps(manifest))
+            data_layer = nemo_asr.AudioToTextDataLayer(
+                shuffle=False,
+                manifest_filepath=manifest_file_path,
+                labels=self.labels,
+                batch_size=1,
+            )
+
+            # Define inference DAG
+            audio_signal, audio_signal_len, _, _ = data_layer()
+            processed_signal, processed_signal_len = self.data_preprocessor(
+                input_signal=audio_signal, length=audio_signal_len
+            )
+            encoded, encoded_len = self.jasper_encoder(
+                audio_signal=processed_signal, length=processed_signal_len
+            )
+            log_probs = self.jasper_decoder(encoder_output=encoded)
+            predictions = self.greedy_decoder(log_probs=log_probs)
+
+            # if ENABLE_NGRAM:
+            #     logging.info('Running with beam search')
+            #     beam_predictions = beam_search_with_lm(log_probs=log_probs, log_probs_length=encoded_len)
+            #     eval_tensors = [beam_predictions]
+
+            # if greedy:
+            eval_tensors = [predictions]
+
+            tensors = self.neural_factory.infer(tensors=eval_tensors)
+        finally:
+            # delete=False lets the files outlive their handles so the data
+            # layer can open them by path; remove them once inference is over
+            # so repeated calls do not pile up files in WORK_DIR.
+            for leftover in (audio_file_path, manifest_file_path):
+                if leftover is not None and os.path.exists(leftover):
+                    os.remove(leftover)
+        if greedy:
+            from nemo.collections.asr.helpers import post_process_predictions
+
+            prediction = post_process_predictions(tensors[0], self.labels)
+        else:
+            # NOTE(review): leftover from the disabled beam-search path; only
+            # the greedy predictions are evaluated above, so this indexing
+            # probably does not yield text -- confirm before relying on it.
+            prediction = tensors[0][0][0][0][1]
+        prediction_text = ". ".join(prediction)
+        return prediction_text
+
+    def transcribe_file(self, audio_file):
+        """Transcribe ``audio_file`` and save the text next to it.
+
+        The transcription is written to the same path with a ``.txt``
+        suffix.  Errors (``Exception`` subclasses) are logged, not raised.
+
+        Args:
+            audio_file: path of the audio file (``str`` or ``pathlib.Path``).
+        """
+        audio_file = Path(audio_file)
+        tscript_file_path = audio_file.with_suffix(".txt")
+        audio_file_path = str(audio_file)
+        try:
+            with open(audio_file_path, "rb") as af:
+                audio_data = af.read()
+            transcription = self.transcribe(audio_data)
+            with open(tscript_file_path, "w") as tf:
+                tf.write(transcription)
+        except Exception as e:
+            # Exception rather than BaseException, so that Ctrl-C and
+            # SystemExit still propagate.
+            logging.error(f"an error occurred during transcription: {e}")
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..308ce60
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,22 @@
+from setuptools import setup
+
+setup(
+    name="jasper-asr",
+    version="0.1",
+    description="Generates text from speech audio",
+    url="http://github.com/malarinv/jasper-asr",
+    author="Malar Kannan",
+    author_email="malarkannan.invention@gmail.com",
+    license="MIT",
+    # The sources use f-strings, which need Python 3.6 or newer.
+    python_requires=">=3.6",
+    install_requires=[
+        "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit",
+        # jasper/asr.py imports ruamel.yaml directly.
+        "ruamel.yaml",
+    ],
+    # Ship the jasper package that the console script below imports.
+    packages=["jasper"],
+    entry_points={"console_scripts": ["jasper_transcribe = jasper.__main__:main"]},
+    zip_safe=False,
+)