diff --git a/jasper/server.py b/jasper/server.py index 09ade5b..163a022 100644 --- a/jasper/server.py +++ b/jasper/server.py @@ -5,21 +5,13 @@ import rpyc from rpyc.utils.server import ThreadedServer from .asr import JasperASR - - -MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml") -CHECKPOINT_ENCODER = os.environ.get( - "JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt" -) -CHECKPOINT_DECODER = os.environ.get( - "JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt" -) -KEN_LM = os.environ.get("JASPER_KEN_LM", None) - -asr_recognizer = JasperASR(MODEL_YAML, CHECKPOINT_ENCODER, CHECKPOINT_DECODER, KEN_LM) +from .utils import arg_parser class ASRService(rpyc.Service): + def __init__(self, asr_recognizer): + self.asr = asr_recognizer + def on_connect(self, conn): # code that runs when a connection is created # (to init the service, if needed) @@ -31,24 +23,33 @@ class ASRService(rpyc.Service): pass def exposed_transcribe(self, utterance: bytes): # this is an exposed method - speech_audio = asr_recognizer.transcribe(utterance) + speech_audio = self.asr.transcribe(utterance) return speech_audio def exposed_transcribe_cb( self, utterance: bytes, respond ): # this is an exposed method - speech_audio = asr_recognizer.transcribe(utterance) + speech_audio = self.asr.transcribe(utterance) respond(speech_audio) def main(): + parser = arg_parser('jasper_transcribe') + parser.description = 'jasper asr rpyc server' + parser.add_argument( + "--port", type=int, default=int(os.environ.get("ASR_RPYC_PORT", "8044")), help="port to listen on" + ) + args = parser.parse_args() + args_dict = vars(args) + port = args_dict.pop("port") + jasper_asr = JasperASR(**args_dict) + service = ASRService(jasper_asr) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - port = int(os.environ.get("ASR_RPYC_PORT", "8044")) - logging.info("starting tts server...") - t = ThreadedServer(ASRService, port=port) + logging.info("starting asr server...") + t = ThreadedServer(service, port=port) t.start() diff --git a/jasper/transcribe.py b/jasper/transcribe.py new file mode 100644 index 0000000..d58fe9d --- /dev/null +++ b/jasper/transcribe.py @@ -0,0 +1,22 @@ +from pathlib import Path +from .asr import JasperASR +from .utils import arg_parser + + +def main(): + parser = arg_parser('jasper_transcribe') + parser.description = 'transcribe audio file to text' + parser.add_argument( + "audio_file", + type=Path, + help="audio file(16khz 1channel int16 wav) to transcribe", + ) + parser.add_argument( + "--greedy", type=bool, default=False, help="enables greedy decoding" + ) + args = parser.parse_args() + args_dict = vars(args) + audio_file = args_dict.pop("audio_file") + greedy = args_dict.pop("greedy") + jasper_asr = JasperASR(**args_dict) + jasper_asr.transcribe_file(audio_file, greedy) diff --git a/jasper/__main__.py b/jasper/utils.py similarity index 62% rename from jasper/__main__.py rename to jasper/utils.py index 42b383a..5f5ed3e 100644 --- a/jasper/__main__.py +++ b/jasper/utils.py @@ -1,7 +1,6 @@ import os import argparse from pathlib import Path -from .asr import JasperASR MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml") CHECKPOINT_ENCODER = os.environ.get( @@ -13,18 +12,9 @@ CHECKPOINT_DECODER = os.environ.get( KEN_LM = os.environ.get("JASPER_KEN_LM", "/models/jasper/kenlm.pt") -def arg_parser(): - prog = Path(__file__).stem +def arg_parser(prog): parser = argparse.ArgumentParser( - prog=prog, description=f"generates transcription of the audio_file" - ) - parser.add_argument( - "audio_file", - type=Path, - help="audio file(16khz 1channel int16 wav) to transcribe", - ) - parser.add_argument( - "--greedy", type=bool, default=False, help="enables greedy decoding" + prog=prog, description=f"convert speech to text" ) parser.add_argument( "--model_yaml", @@ -48,13 +38,3 @@ def arg_parser(): "--language_model", type=Path, default=None, help="kenlm language model file" ) return parser - - -def main(): - parser = arg_parser() - args = parser.parse_args() - args_dict = vars(args) - audio_file = args_dict.pop("audio_file") - greedy = args_dict.pop("greedy") - jasper_asr = JasperASR(**args_dict) - jasper_asr.transcribe_file(audio_file, greedy) diff --git a/setup.py b/setup.py index b3ecc83..b18df24 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,8 @@ setup( packages=["."], entry_points={ "console_scripts": [ - "jasper_transcribe = jasper.__main__:main", - "asr_rpyc_server = jasper.server:main", + "jasper_transcribe = jasper.transcribe:main", + "jasper_asr_rpyc_server = jasper.server:main", ] }, zip_safe=False,