"""Command-line entry points that expose a wav2vec2 ASR model over RPyC."""

import logging
import os
from pathlib import Path

import typer

from ...utils.serve import ASRService
from plume.utils import lazy_callable

# The server and model classes are referenced through ``lazy_callable`` with
# their dotted import paths instead of being imported directly.
# NOTE(review): presumably this defers the heavy imports until the first
# call -- confirm against ``plume.utils.lazy_callable``.
ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
Wav2Vec2TransformersASR = lazy_callable(
    "plume.models.wav2vec2_transformers.asr.Wav2Vec2TransformersASR"
)

# Default listening port, read once at import time; the ``ASR_RPYC_PORT``
# environment variable overrides the built-in 8044.
_DEFAULT_PORT = int(os.environ.get("ASR_RPYC_PORT", "8044"))

app = typer.Typer()


@app.command()
def rpyc(
    w2v_path: Path = Path("/path/to/base.pt"),
    ctc_path: Path = Path("/path/to/ctc.pt"),
    target_dict_path: Path = Path("/path/to/dict.ltr.txt"),
    port: int = _DEFAULT_PORT,
):
    """Build the ASR model from explicit file paths and serve it over RPyC.

    The model is constructed from ``ctc_path``, ``w2v_path`` and
    ``target_dict_path``, wrapped in an ``ASRService`` and handed to a
    ``ThreadedServer`` listening on ``port``.
    """
    # Configure logging before anything else so that records emitted while
    # the model is being constructed already use this level and format.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # NOTE(review): the model is called with ctc_path first, which differs
    # from the parameter order of this function -- confirm that it matches
    # the Wav2Vec2TransformersASR constructor.
    w2vasr = Wav2Vec2TransformersASR(ctc_path, w2v_path, target_dict_path)
    service = ASRService(w2vasr)
    logging.info("starting asr server...")
    server = ThreadedServer(service, port=port)
    server.start()


@app.command()
def rpyc_dir(model_dir: Path, port: int = _DEFAULT_PORT):
    """Serve the ASR model whose files live in a single directory.

    ``model_dir`` is expected to contain ``base.pt``, ``ctc.pt`` and
    ``dict.ltr.txt``; the resolved paths and ``port`` are forwarded to
    the ``rpyc`` command function.
    """
    w2v_path = model_dir / "base.pt"
    ctc_path = model_dir / "ctc.pt"
    target_dict_path = model_dir / "dict.ltr.txt"
    rpyc(w2v_path, ctc_path, target_dict_path, port)


def main():
    """Run the Typer application."""
    app()


if __name__ == "__main__":
    main()