plume-asr/plume/utils/transcribe.py

185 lines
5.5 KiB
Python
Raw Normal View History

2021-02-23 14:13:33 +00:00
import os
import logging
from io import BytesIO
from pathlib import Path
from functools import lru_cache
import typer
# import rpyc
# from tqdm import tqdm
# from pydub import AudioSegment
# from pydub.silence import split_on_silence
from plume.utils import lazy_module, lazy_callable
# Optional heavy dependencies go through plume.utils' lazy_module /
# lazy_callable helpers (presumably deferring the real import until first
# use -- see plume.utils).
rpyc = lazy_module("rpyc")
AudioSegment = lazy_callable("pydub.AudioSegment")
split_on_silence = lazy_callable("pydub.silence.split_on_silence")

# Typer application that the @app.command() functions below register onto.
app = typer.Typer()

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# rpyc ASR server location. Note the host comes from JASR_RPYC_HOST while the
# port comes from ASR_RPYC_PORT.
ASR_RPYC_HOST = os.getenv("JASR_RPYC_HOST", "localhost")
ASR_RPYC_PORT = int(os.getenv("ASR_RPYC_PORT", "8044"))

# Triton gRPC ASR server location and the model name to query.
TRITON_ASR_MODEL = os.getenv("TRITON_ASR_MODEL", "slu_wav2vec2")
TRITON_GRPC_ASR_HOST = os.getenv("TRITON_GRPC_ASR_HOST", "localhost")
TRITON_GRPC_ASR_PORT = int(os.getenv("TRITON_GRPC_ASR_PORT", "8001"))
@lru_cache()
def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
    """Connect to an rpyc ASR server; return a transcriber and an audio-prep function.

    The result is memoized with ``lru_cache``, so repeated calls with the same
    host/port return the same objects from the first successful connection
    (a failed attempt raises and is not cached).

    Args:
        asr_host: rpyc server host. Defaults to the ``JASR_RPYC_HOST``
            env-var, else ``"localhost"``.
        asr_port: rpyc server port. Defaults to the ``ASR_RPYC_PORT``
            env-var, else ``8044``.

    Returns:
        A ``(transcribe, audio_prep)`` tuple. ``transcribe`` is the remote
        ``root.transcribe`` callable. ``audio_prep`` converts a pydub
        ``AudioSegment`` to mono, 2-byte samples at 16 kHz.

    Raises:
        Exception: if the server refuses the connection. The original
            ``ConnectionRefusedError`` is chained as ``__cause__``.
    """
    logger.info(f"connecting to asr server at {asr_host}:{asr_port}")
    try:
        asr = rpyc.connect(asr_host, asr_port).root
    except ConnectionRefusedError as err:
        # Name the env-vars this module actually reads; the previous message
        # referred to a JASPER_ASR_RPYC_HOST variable that nothing reads.
        raise Exception(
            f"asr server at {asr_host}:{asr_port} refused the connection; "
            "check env-vars JASR_RPYC_HOST / ASR_RPYC_PORT"
        ) from err
    logger.info("connected to asr server successfully")

    def audio_prep(aud_seg):
        """Return *aud_seg* converted to mono, 2-byte samples at 16 kHz."""
        return aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)

    return asr.transcribe, audio_prep
def triton_transcribe_grpc_gen(
    asr_host=TRITON_GRPC_ASR_HOST,
    asr_port=TRITON_GRPC_ASR_PORT,
    asr_model=TRITON_ASR_MODEL,
    method="chunked",
    chunk_msec=5000,
    sil_msec=500,
    sep=" ",
):
    """Build a Triton gRPC transcriber and a matching audio-prep function.

    Args:
        asr_host: Triton gRPC host (default from ``TRITON_GRPC_ASR_HOST``).
        asr_port: Triton gRPC port (default from ``TRITON_GRPC_ASR_PORT``).
        asr_model: Triton model name to query (default from
            ``TRITON_ASR_MODEL``).
        method: One of three modes:
            ``"whole"`` sends the segment in a single request.
            ``"chunked"`` slices it into ``chunk_msec`` windows.
            ``"silence"`` first splits on silences of at least ``sil_msec``,
            then slices each piece into ``chunk_msec`` windows.
        chunk_msec: Window length in milliseconds for the chunked modes.
        sil_msec: Minimum silence length for ``"silence"`` splitting. Also the
            length of the silence padded onto both ends of every chunk.
        sep: Separator used to join per-chunk transcripts.

    Returns:
        A ``(transcriber, audio_prep)`` tuple. ``transcriber`` takes a pydub
        ``AudioSegment`` and returns the transcript as ``str``. ``audio_prep``
        converts a segment to mono, 2-byte samples at 16 kHz.

    Raises:
        ValueError: if ``method`` is unsupported. This is an ``Exception``
            subclass, so existing ``except Exception`` handlers still catch it.
    """
    from tritonclient.utils import np_to_triton_dtype
    import tritonclient.grpc as grpcclient
    import numpy as np

    sup_meth = ["chunked", "silence", "whole"]
    if method not in sup_meth:
        meths = "|".join(sup_meth)
        raise ValueError(f"unsupported method {method}. pick one of {meths}")

    client = grpcclient.InferenceServerClient(f"{asr_host}:{asr_port}")

    def transcriber(aud_seg):
        """Send *aud_seg* to Triton as one WAV blob and return the decoded text."""
        af = BytesIO()
        aud_seg.export(af, format="wav")
        # dtype=object is required here. A plain np.array([bytes]) gets a
        # fixed-width "S" dtype, and numpy strips trailing NUL bytes when the
        # element is read back for serialization. A WAV ending in digital
        # silence (e.g. the padding added in chunked_transcriber) would then
        # be truncated before it reached the server.
        input_audio_data = np.array([af.getvalue()], dtype=np.object_)
        inputs = [
            grpcclient.InferInput(
                "INPUT_AUDIO",
                input_audio_data.shape,
                np_to_triton_dtype(input_audio_data.dtype),
            )
        ]
        inputs[0].set_data_from_numpy(input_audio_data)
        outputs = [grpcclient.InferRequestedOutput("OUTPUT_TEXT")]
        response = client.infer(asr_model, inputs, request_id=str(1), outputs=outputs)
        transcript = response.as_numpy("OUTPUT_TEXT")[0]
        return transcript.decode("utf-8")

    def chunked_transcriber(aud_seg):
        """Transcribe *aud_seg* piece by piece and join the results with *sep*."""
        if method == "silence":
            sil_chunks = split_on_silence(
                aud_seg,
                min_silence_len=sil_msec,
                silence_thresh=-50,
                keep_silence=500,
            )
            # Further slice each non-silent piece into chunk_msec windows.
            chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
        else:
            chunks = aud_seg[::chunk_msec]
        # TODO: overlapping windows are not implemented.
        # Pad every chunk with sil_msec of silence on both sides.
        sil_pad = AudioSegment.silent(duration=sil_msec)
        return sep.join(transcriber(sil_pad + seg + sil_pad) for seg in chunks)

    def audio_prep(aud_seg):
        """Return *aud_seg* converted to mono, 2-byte samples at 16 kHz."""
        return aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)

    whole_transcriber = transcriber if method == "whole" else chunked_transcriber
    return whole_transcriber, audio_prep
@app.command()
def file(audio_file: Path, write_file: bool = False, chunked: bool = True):
    """Transcribe AUDIO_FILE with the Triton gRPC backend and print the text.

    Args:
        audio_file: Path to an audio file that pydub can decode.
        write_file: Also write the transcript next to the audio file, with
            its suffix replaced by ``.txt``.
        chunked: If true (the default), transcribe in fixed-length chunks
            (``method="chunked"``). ``--no-chunked`` sends the whole file in
            one request (``method="whole"``). Previously this option was
            accepted but ignored.
    """
    from pydub import AudioSegment

    aseg = AudioSegment.from_file(audio_file)
    transcriber, prep = triton_transcribe_grpc_gen(
        method="chunked" if chunked else "whole"
    )
    transcription = transcriber(prep(aseg))
    typer.echo(transcription)
    if write_file:
        tscript_file_path = audio_file.with_suffix(".txt")
        # Explicit UTF-8 so non-ASCII transcripts don't depend on the locale.
        tscript_file_path.write_text(transcription, encoding="utf-8")
def _print_best_time(label, transcriber, aud_seg, number=100, repeat=10):
    """Time ``transcriber(aud_seg)`` with ``timeit.Timer`` and print the result.

    Runs ``repeat`` rounds of ``number`` calls each. Prints the fastest round
    converted to milliseconds per call, prefixed with ``label`` so results
    from different backends can be told apart.
    """
    from timeit import Timer

    timer = Timer(lambda: transcriber(aud_seg))
    time_taken = timer.repeat(repeat, number=number)
    best = min(time_taken) * 1000 / number
    print(f"{label}: {number} loops, best of {repeat}: {best:.3f} msec per loop")


@app.command()
def benchmark(audio_file: Path):
    """Time the rpyc and Triton gRPC transcribers on the same AUDIO_FILE.

    Each backend prepares the audio with its own ``audio_prep``. It is then
    timed for 10 rounds of 100 calls, and the best round is printed.
    """
    import time
    from pydub import AudioSegment

    transcriber, audio_prep = transcribe_rpyc_gen()
    file_seg = AudioSegment.from_file(audio_file)
    _print_best_time("rpyc", transcriber, audio_prep(file_seg))

    # NOTE(review): fixed pause between the two backends. Its purpose is not
    # evident from this file (cool-down between runs?); confirm before removing.
    time.sleep(5)

    transcriber, audio_prep = triton_transcribe_grpc_gen()
    _print_best_time("triton", transcriber, audio_prep(file_seg))
def main():
    """Run the Typer CLI app that this module's commands are registered on."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()