mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-06-13 12:32:08 +00:00
refactored arg parsing to take server cli args
This commit is contained in:
@@ -5,21 +5,13 @@ import rpyc
|
||||
from rpyc.utils.server import ThreadedServer
|
||||
|
||||
from .asr import JasperASR
|
||||
|
||||
|
||||
MODEL_YAML = os.environ.get("JASPER_MODEL_CONFIG", "/models/jasper/jasper10x5dr.yaml")
|
||||
CHECKPOINT_ENCODER = os.environ.get(
|
||||
"JASPER_ENCODER_CHECKPOINT", "/models/jasper/JasperEncoder-STEP-265520.pt"
|
||||
)
|
||||
CHECKPOINT_DECODER = os.environ.get(
|
||||
"JASPER_DECODER_CHECKPOINT", "/models/jasper/JasperDecoderForCTC-STEP-265520.pt"
|
||||
)
|
||||
KEN_LM = os.environ.get("JASPER_KEN_LM", None)
|
||||
|
||||
asr_recognizer = JasperASR(MODEL_YAML, CHECKPOINT_ENCODER, CHECKPOINT_DECODER, KEN_LM)
|
||||
from .utils import arg_parser
|
||||
|
||||
|
||||
class ASRService(rpyc.Service):
|
||||
def __init__(self, asr_recognizer):
|
||||
self.asr = asr_recognizer
|
||||
|
||||
def on_connect(self, conn):
|
||||
# code that runs when a connection is created
|
||||
# (to init the service, if needed)
|
||||
@@ -31,24 +23,33 @@ class ASRService(rpyc.Service):
|
||||
pass
|
||||
|
||||
def exposed_transcribe(self, utterance: bytes): # this is an exposed method
|
||||
speech_audio = asr_recognizer.transcribe(utterance)
|
||||
speech_audio = self.asr.transcribe(utterance)
|
||||
return speech_audio
|
||||
|
||||
def exposed_transcribe_cb(
|
||||
self, utterance: bytes, respond
|
||||
): # this is an exposed method
|
||||
speech_audio = asr_recognizer.transcribe(utterance)
|
||||
speech_audio = self.asr.transcribe(utterance)
|
||||
respond(speech_audio)
|
||||
|
||||
|
||||
def main():
|
||||
parser = arg_parser('jasper_transcribe')
|
||||
parser.description = 'jasper asr rpyc server'
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=int(os.environ.get("ASR_RPYC_PORT", "8044")), help="port to listen on"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args_dict = vars(args)
|
||||
port = args_dict.pop("port")
|
||||
jasper_asr = JasperASR(**args_dict)
|
||||
service = ASRService(jasper_asr)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
port = int(os.environ.get("ASR_RPYC_PORT", "8044"))
|
||||
logging.info("starting tts server...")
|
||||
t = ThreadedServer(ASRService, port=port)
|
||||
logging.info("starting asr server...")
|
||||
t = ThreadedServer(service, port=port)
|
||||
t.start()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user