diff --git a/jasper/__main__.py b/jasper/__main__.py index 04431b3..42b383a 100644 --- a/jasper/__main__.py +++ b/jasper/__main__.py @@ -10,6 +10,7 @@ CHECKPOINT_ENCODER = os.environ.get( CHECKPOINT_DECODER = os.environ.get( "JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt" ) +KEN_LM = os.environ.get("JASPER_KEN_LM", "/models/jasper/kenlm.pt") def arg_parser(): @@ -18,15 +19,42 @@ def arg_parser(): prog=prog, description=f"generates transcription of the audio_file" ) parser.add_argument( - "--audio_file", + "audio_file", type=Path, help="audio file(16khz 1channel int16 wav) to transcribe", ) + parser.add_argument( + "--greedy", action="store_true", help="enables greedy decoding" + ) + parser.add_argument( + "--model_yaml", + type=Path, + default=Path(MODEL_YAML), + help="model config yaml file", + ) + parser.add_argument( + "--encoder_checkpoint", + type=Path, + default=Path(CHECKPOINT_ENCODER), + help="encoder checkpoint weights file", + ) + parser.add_argument( + "--decoder_checkpoint", + type=Path, + default=Path(CHECKPOINT_DECODER), + help="decoder checkpoint weights file", + ) + parser.add_argument( + "--language_model", type=Path, default=None, help="kenlm language model file" + ) return parser def main(): parser = arg_parser() args = parser.parse_args() - jasper_asr = JasperASR(MODEL_YAML, CHECKPOINT_ENCODER, CHECKPOINT_DECODER) - jasper_asr.transcribe_file(args.audio_file) + args_dict = vars(args) + audio_file = args_dict.pop("audio_file") + greedy = args_dict.pop("greedy") + jasper_asr = JasperASR(**args_dict) + jasper_asr.transcribe_file(audio_file, greedy) diff --git a/jasper/asr.py b/jasper/asr.py index 2fcc898..de3d78f 100644 --- a/jasper/asr.py +++ b/jasper/asr.py @@ -1,8 +1,11 @@ +import os import tempfile from ruamel.yaml import YAML import json import nemo import nemo.collections.asr as nemo_asr +import wave +from nemo.collections.asr.helpers import post_process_predictions logging = nemo.logging @@ -12,7 +15,9 
@@ WORK_DIR = "/tmp" class JasperASR(object): """docstring for JasperASR.""" - def __init__(self, model_yaml, encoder_checkpoint, decoder_checkpoint): + def __init__( + self, model_yaml, encoder_checkpoint, decoder_checkpoint, language_model=None + ): super(JasperASR, self).__init__() # Read model YAML yaml = YAML(typ="safe") @@ -36,14 +41,30 @@ class JasperASR(object): ) self.jasper_decoder.restore_from(decoder_checkpoint, local_rank=0) self.greedy_decoder = nemo_asr.GreedyCTCDecoder() + self.beam_search_with_lm = None + if language_model: + self.beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM( + vocab=self.labels, + beam_width=64, + alpha=2.0, + beta=1.0, + lm_path=str(language_model), + num_cpus=max(os.cpu_count() or 1, 1), + ) def transcribe(self, audio_data, greedy=True): audio_file = tempfile.NamedTemporaryFile( dir=WORK_DIR, prefix="jasper_audio.", delete=False ) - audio_file.write(audio_data) + # audio_file.write(audio_data) audio_file.close() audio_file_path = audio_file.name + wf = wave.open(audio_file_path, "w") + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframesraw(audio_data) + wf.close() manifest = {"audio_filepath": audio_file_path, "duration": 60, "text": "todo"} manifest_file = tempfile.NamedTemporaryFile( dir=WORK_DIR, prefix="jasper_manifest.", delete=False, mode="w" @@ -69,32 +90,32 @@ class JasperASR(object): log_probs = self.jasper_decoder(encoder_output=encoded) predictions = self.greedy_decoder(log_probs=log_probs) - # if ENABLE_NGRAM: - # logging.info('Running with beam search') - # beam_predictions = beam_search_with_lm(log_probs=log_probs, log_probs_length=encoded_len) - # eval_tensors = [beam_predictions] - - # if greedy: - eval_tensors = [predictions] + if greedy: + eval_tensors = [predictions] + else: + if self.beam_search_with_lm: + logging.info("Running with beam search") + beam_predictions = self.beam_search_with_lm( + log_probs=log_probs, log_probs_length=encoded_len + ) + eval_tensors = 
[beam_predictions] + else: + logging.info( + "language_model not specified. falling back to greedy decoding." + ) + eval_tensors = [predictions] tensors = self.neural_factory.infer(tensors=eval_tensors) - if greedy: - from nemo.collections.asr.helpers import post_process_predictions - - prediction = post_process_predictions(tensors[0], self.labels) - else: - prediction = tensors[0][0][0][0][1] + prediction = post_process_predictions(tensors[0], self.labels) prediction_text = ". ".join(prediction) return prediction_text - def transcribe_file(self, audio_file): + def transcribe_file(self, audio_file, *args, **kwargs): tscript_file_path = audio_file.with_suffix(".txt") audio_file_path = str(audio_file) - try: - with open(audio_file_path, "rb") as af: - audio_data = af.read() - transcription = self.transcribe(audio_data) - with open(tscript_file_path, "w") as tf: - tf.write(transcription) - except BaseException as e: - logging.info(f"an error occurred during transcrption: {e}") + with wave.open(audio_file_path, "r") as af: + frame_count = af.getnframes() + audio_data = af.readframes(frame_count) + transcription = self.transcribe(audio_data, *args, **kwargs) + with open(tscript_file_path, "w") as tf: + tf.write(transcription) diff --git a/jasper/server.py b/jasper/server.py new file mode 100644 index 0000000..09ade5b --- /dev/null +++ b/jasper/server.py @@ -0,0 +1,56 @@ +import os +import logging + +import rpyc +from rpyc.utils.server import ThreadedServer + +from .asr import JasperASR + + +MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml") +CHECKPOINT_ENCODER = os.environ.get( + "JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt" +) +CHECKPOINT_DECODER = os.environ.get( + "JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt" +) +KEN_LM = os.environ.get("JASPER_KEN_LM", None) + +asr_recognizer = JasperASR(MODEL_YAML, CHECKPOINT_ENCODER, CHECKPOINT_DECODER, KEN_LM) + + +class 
ASRService(rpyc.Service): + def on_connect(self, conn): + # code that runs when a connection is created + # (to init the service, if needed) + pass + + def on_disconnect(self, conn): + # code that runs after the connection has already closed + # (to finalize the service, if needed) + pass + + def exposed_transcribe(self, utterance: bytes): # this is an exposed method + speech_audio = asr_recognizer.transcribe(utterance) + return speech_audio + + def exposed_transcribe_cb( + self, utterance: bytes, respond + ): # this is an exposed method + speech_audio = asr_recognizer.transcribe(utterance) + respond(speech_audio) + + +def main(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + port = int(os.environ.get("ASR_RPYC_PORT", "8044")) + logging.info("starting asr server...") + t = ThreadedServer(ASRService, port=port) + t.start() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 308ce60..b3ecc83 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,12 @@ from setuptools import setup +requirements = [ + "ruamel.yaml", + "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit", +] + +extra_requirements = {"server": ["rpyc==4.1.4"]} + setup( name="jasper-asr", version="0.1", @@ -8,10 +15,14 @@ setup( author="Malar Kannan", author_email="malarkannan.invention@gmail.com", license="MIT", - install_requires=[ - "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo.git@09e3ba4dfe333f86d6c5c1048e07210924294be9#egg=nemo_toolkit" - ], + install_requires=requirements, + extras_require=extra_requirements, packages=["."], - entry_points={"console_scripts": ["jasper_transcribe = jasper.__main__:main"]}, + entry_points={ + "console_scripts": [ + "jasper_transcribe = jasper.__main__:main", + "asr_rpyc_server = jasper.server:main", + ] + }, zip_safe=False, )