refactored arg parsing to take server cli args
parent
604d0bc87f
commit
f7ebd8e90a
|
|
@ -5,21 +5,13 @@ import rpyc
|
||||||
from rpyc.utils.server import ThreadedServer
|
from rpyc.utils.server import ThreadedServer
|
||||||
|
|
||||||
from .asr import JasperASR
|
from .asr import JasperASR
|
||||||
|
from .utils import arg_parser
|
||||||
|
|
||||||
MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml")
|
|
||||||
CHECKPOINT_ENCODER = os.environ.get(
|
|
||||||
"JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt"
|
|
||||||
)
|
|
||||||
CHECKPOINT_DECODER = os.environ.get(
|
|
||||||
"JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt"
|
|
||||||
)
|
|
||||||
KEN_LM = os.environ.get("JASPER_KEN_LM", None)
|
|
||||||
|
|
||||||
asr_recognizer = JasperASR(MODEL_YAML, CHECKPOINT_ENCODER, CHECKPOINT_DECODER, KEN_LM)
|
|
||||||
|
|
||||||
|
|
||||||
class ASRService(rpyc.Service):
|
class ASRService(rpyc.Service):
|
||||||
|
def __init__(self, asr_recognizer):
|
||||||
|
self.asr = asr_recognizer
|
||||||
|
|
||||||
def on_connect(self, conn):
|
def on_connect(self, conn):
|
||||||
# code that runs when a connection is created
|
# code that runs when a connection is created
|
||||||
# (to init the service, if needed)
|
# (to init the service, if needed)
|
||||||
|
|
@ -31,24 +23,33 @@ class ASRService(rpyc.Service):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def exposed_transcribe(self, utterance: bytes): # this is an exposed method
|
def exposed_transcribe(self, utterance: bytes): # this is an exposed method
|
||||||
speech_audio = asr_recognizer.transcribe(utterance)
|
speech_audio = self.asr.transcribe(utterance)
|
||||||
return speech_audio
|
return speech_audio
|
||||||
|
|
||||||
def exposed_transcribe_cb(
|
def exposed_transcribe_cb(
|
||||||
self, utterance: bytes, respond
|
self, utterance: bytes, respond
|
||||||
): # this is an exposed method
|
): # this is an exposed method
|
||||||
speech_audio = asr_recognizer.transcribe(utterance)
|
speech_audio = self.asr.transcribe(utterance)
|
||||||
respond(speech_audio)
|
respond(speech_audio)
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Run the Jasper ASR rpyc server from the command line.

    Builds the shared speech-to-text argument parser, adds a ``--port``
    option (defaulting to the ``ASR_RPYC_PORT`` environment variable, or
    8044 when it is unset), constructs a ``JasperASR`` from every remaining
    parsed argument, and serves it through ``ASRService`` on a
    ``ThreadedServer``.
    """
    # Configure logging first: logging.basicConfig() is a no-op once the root
    # logger already has handlers, so it must run before anything else logs.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Use the console-script name registered for this entry point so that
    # usage/help output names the server, not the transcribe command.
    parser = arg_parser('jasper_asr_rpyc_server')
    parser.description = 'jasper asr rpyc server'
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.environ.get("ASR_RPYC_PORT", "8044")),
        help="port to listen on",
    )
    args = parser.parse_args()

    # Everything except the port is forwarded to JasperASR as keyword args.
    args_dict = vars(args)
    port = args_dict.pop("port")
    jasper_asr = JasperASR(**args_dict)

    # A single service instance is shared so the recognizer is built once.
    service = ASRService(jasper_asr)
    logging.info("starting asr server...")
    t = ThreadedServer(service, port=port)
    t.start()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from .asr import JasperASR
|
||||||
|
from .utils import arg_parser
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` treats every non-empty string (including "False") as True,
    so the flag value is interpreted explicitly here instead.

    Raises:
        ValueError: if the text is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def main():
    """Transcribe one audio file from the command line.

    Builds the shared speech-to-text argument parser, adds the positional
    ``audio_file`` and the ``--greedy`` option, constructs a ``JasperASR``
    from every remaining parsed argument, and calls ``transcribe_file`` with
    the audio path and the greedy setting.
    """
    parser = arg_parser('jasper_transcribe')
    parser.description = 'transcribe audio file to text'
    parser.add_argument(
        "audio_file",
        type=Path,
        help="audio file(16khz 1channel int16 wav) to transcribe",
    )
    parser.add_argument(
        "--greedy",
        type=_parse_bool,
        default=False,
        help="enables greedy decoding",
    )
    args = parser.parse_args()

    # Strip the options consumed here; the rest configure the recognizer.
    args_dict = vars(args)
    audio_file = args_dict.pop("audio_file")
    greedy = args_dict.pop("greedy")
    jasper_asr = JasperASR(**args_dict)
    jasper_asr.transcribe_file(audio_file, greedy)
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from .asr import JasperASR
|
|
||||||
|
|
||||||
MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml")
|
MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml")
|
||||||
CHECKPOINT_ENCODER = os.environ.get(
|
CHECKPOINT_ENCODER = os.environ.get(
|
||||||
|
|
@ -13,18 +12,9 @@ CHECKPOINT_DECODER = os.environ.get(
|
||||||
KEN_LM = os.environ.get("JASPER_KEN_LM", "/models/jasper/kenlm.pt")
|
KEN_LM = os.environ.get("JASPER_KEN_LM", "/models/jasper/kenlm.pt")
|
||||||
|
|
||||||
|
|
||||||
def arg_parser():
|
def arg_parser(prog):
|
||||||
prog = Path(__file__).stem
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog=prog, description=f"generates transcription of the audio_file"
|
prog=prog, description=f"convert speech to text"
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"audio_file",
|
|
||||||
type=Path,
|
|
||||||
help="audio file(16khz 1channel int16 wav) to transcribe",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--greedy", type=bool, default=False, help="enables greedy decoding"
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_yaml",
|
"--model_yaml",
|
||||||
|
|
@ -48,13 +38,3 @@ def arg_parser():
|
||||||
"--language_model", type=Path, default=None, help="kenlm language model file"
|
"--language_model", type=Path, default=None, help="kenlm language model file"
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = arg_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
args_dict = vars(args)
|
|
||||||
audio_file = args_dict.pop("audio_file")
|
|
||||||
greedy = args_dict.pop("greedy")
|
|
||||||
jasper_asr = JasperASR(**args_dict)
|
|
||||||
jasper_asr.transcribe_file(audio_file, greedy)
|
|
||||||
4
setup.py
4
setup.py
|
|
@ -20,8 +20,8 @@ setup(
|
||||||
packages=["."],
|
packages=["."],
|
||||||
entry_points={
|
entry_points={
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
"jasper_transcribe = jasper.__main__:main",
|
"jasper_transcribe = jasper.transcribe:main",
|
||||||
"asr_rpyc_server = jasper.server:main",
|
"jasper_asr_rpyc_server = jasper.server:main",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue