plume-asr/plume/models/wav2vec2/serve.py

54 lines
1.4 KiB
Python

import os
import logging
from pathlib import Path
# from rpyc.utils.server import ThreadedServer
import typer
from ...utils.serve import ASRService
from plume.utils import lazy_callable
# from .asr import Wav2Vec2ASR
ThreadedServer = lazy_callable('rpyc.utils.server.ThreadedServer')
Wav2Vec2ASR = lazy_callable('plume.models.wav2vec2.asr.Wav2Vec2ASR')
app = typer.Typer()
@app.command()
def rpyc(
w2v_path: Path = "/path/to/base.pt",
ctc_path: Path = "/path/to/ctc.pt",
target_dict_path: Path = "/path/to/dict.ltr.txt",
port: int = int(os.environ.get("ASR_RPYC_PORT", "8044")),
):
for p in [w2v_path, ctc_path, target_dict_path]:
if not p.exists():
logging.info(f"{p} doesn't exists")
return
w2vasr = Wav2Vec2ASR(str(ctc_path), str(w2v_path), str(target_dict_path))
service = ASRService(w2vasr)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logging.info("starting asr server...")
t = ThreadedServer(service, port=port)
t.start()
@app.command()
def rpyc_dir(model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))):
ctc_path = model_dir / Path("ctc.pt")
w2v_path = model_dir / Path("base.pt")
target_dict_path = model_dir / Path("dict.ltr.txt")
rpyc(w2v_path, ctc_path, target_dict_path, port)
def main():
app()
if __name__ == "__main__":
main()