1. set flake8 max-line to 79
2. update streamlit dep to 1.0
3. add dev optional dep key
4. implement mono diarized dataset generation script
5. enable gpu support on asr transformers inference pipeline
6. use typer logging
7. clean up annotation ui with everything other than asr-data keys as optional (including plots)
8. implement chunk_transcribe_meta_gen abstraction for asr chunking logic
9. make ui_persist compatibility change for streamlit 1.0
10. add diarize commands (bugfix)
11. add notebooks for diarization
parent 846f029cf1
commit 79aa5e8578
.flake8

@@ -1,4 +1,4 @@
 [flake8]
 exclude = docs
 ignore = E203, W503
-max-line-length = 119
+max-line-length = 79
File diff suppressed because one or more lines are too long

setup.py

@@ -76,7 +76,7 @@ extra_requirements = {
         "pymongo~=3.10.1",
         "matplotlib~=3.2.1",
         "pydub~=0.24.0",
-        "streamlit~=0.58.0",
+        "streamlit~=1.0.0",
         "natural~=0.2.0",
         "stringcase~=1.2.0",
         "google-cloud-speech~=1.3.1",
@@ -85,6 +85,11 @@ extra_requirements = {
         "pyspellchecker~=0.6.2",
         "google-cloud-texttospeech~=1.0.1",
         "rangehttpserver~=1.2.0",
+        "streamlit~=1.0.0",
     ],
+    "dev": [
+        "jupyterlab~=3.1.18",
+        "ipykernel~=6.4.1",
+    ],
     "crypto": ["cryptography~=3.4.7"],
     "train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],

@@ -1,11 +1,17 @@
 from pathlib import Path
 import shutil

 from tqdm import tqdm
 import typer
 from plume.utils.lazy_import import lazy_module
 from plume.utils.tts import GoogleTTS
-from plume.utils.transcribe import triton_transcribe_grpc_gen
+from plume.utils.transcribe import (
+    triton_transcribe_grpc_gen,
+    chunk_transcribe_meta_gen,
+    transcribe_rpyc_gen,
+)
 from plume.utils.manifest import asr_manifest_writer
+from plume.utils.diarize import diarize_audio_gen
+from plume.utils.extended_path import ExtendedPath

 pydub = lazy_module("pydub")
 app = typer.Typer()
@@ -50,4 +56,46 @@ def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"):
                 "text": transcript,
             }

-    asr_manifest_writer(out_dir / 'manifest.json', data_gen())
+    asr_manifest_writer(out_dir / "manifest.json", data_gen())
+
+
+@app.command()
+def mono_diarize_asr_dataset(audio_dir: Path, out_dir: Path):
+    out_wav_dir = out_dir / "wavs"
+    out_wav_dir.mkdir(exist_ok=True, parents=True)
+    diarize_audio = diarize_audio_gen()
+
+    def data_gen():
+        aud_files = list(audio_dir.glob("*/*.mp3")) + list(
+            audio_dir.glob("*/*.wav")
+        )
+        diameta = ExtendedPath(out_dir / "diameta.json")
+        base_transcriber, base_prep = transcribe_rpyc_gen()
+        transcriber, prep = chunk_transcribe_meta_gen(
+            base_transcriber, base_prep, method="chunked"
+        )
+
+        diametadata = []
+        for af in tqdm(aud_files):
+            try:
+                # raise Exception("Test")
+                for dres in diarize_audio(af):
+                    sample_fname = dres.pop("sample_fname")
+                    out_af = out_wav_dir / sample_fname
+                    wav_bytes = dres.pop("wav")
+                    out_af.write_bytes(wav_bytes)
+                    audio_af = out_af.relative_to(out_dir)
+                    aud_seg = dres.pop("wavseg")
+                    t_seg = prep(aud_seg)
+                    transcript = transcriber(t_seg)
+                    diametadata.append(dres)
+                    yield {
+                        "audio_filepath": str(audio_af),
+                        "duration": aud_seg.duration_seconds,
+                        "text": transcript,
+                    }
+            except Exception as e:
+                print(f"error diarizing/transcribing {af} - {e}")
+        diameta.write_json(diametadata)
+
+    asr_manifest_writer(out_dir / "manifest.json", data_gen())
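Usage note: the new command writes out_dir/wavs/*.wav, a diameta.json with per-turn metadata, and a manifest.json. A minimal sketch of reading the manifest back, assuming asr_manifest_writer emits the common one-JSON-object-per-line ASR manifest layout (the writer itself is not shown in this diff):

    import json
    from pathlib import Path

    def read_manifest(out_dir: Path):
        # each line carries the keys yielded above:
        # audio_filepath, duration, text
        with (out_dir / "manifest.json").open() as f:
            for line in f:
                yield json.loads(line)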

@@ -13,23 +13,38 @@ class Wav2Vec2TransformersASR(object):
     """docstring for Wav2Vec2TransformersASR."""

     def __init__(self, model_dir):
-        super(Wav2Vec2TransformersASR, self).__init__()
+        # super(Wav2Vec2TransformersASR, self).__init__()
+        self.device = "cuda:1" if torch.cuda.is_available() else "cpu"
+        # sd = torch.load(
+        #     model_dir / "pytorch_model.bin", map_location=self.device
+        # )
+        # self.processor = Wav2Vec2Processor.from_pretrained(
+        #     model_dir, state_dict=sd
+        # )
+        # self.model = Wav2Vec2ForCTC.from_pretrained(model_dir, state_dict=sd).to(self.device)
         self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
-        self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
+        self.model = Wav2Vec2ForCTC.from_pretrained(model_dir).to(self.device)

     def transcribe(self, audio_data):
         aud_f = BytesIO(audio_data)
         # net_input = {}
         speech_data, _ = sf.read(aud_f)
         input_values = self.processor(
-            speech_data, return_tensors="pt", padding="longest"
-        ).input_values  # Batch size 1
+            speech_data,
+            return_tensors="pt",
+            padding="longest",
+            sampling_rate=16000,
+        ).input_values.to(
+            self.device
+        )  # Batch size 1

         # retrieve logits
+        # print(f"audio:{speech_data.shape} processed:{input_values.shape}")
         logits = self.model(input_values).logits

+        # print(f"logit shape:{logits.shape}")
         # take argmax and decode
         predicted_ids = torch.argmax(logits, dim=-1)

+        # print(f"predicted_ids shape:{predicted_ids.shape}")
         transcription = self.processor.batch_decode(predicted_ids)[0]
-        return transcription
+        result = transcription.replace('<s>', '')
+        return result
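For context, a standalone sketch of the inference path this change enables, using the public transformers API (the checkpoint name here is the stock HF model, not the project's model_dir; "cuda:1" follows the diff's choice of the second GPU):

    import torch
    from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

    device = "cuda:1" if torch.cuda.is_available() else "cpu"
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)

    def transcribe(speech_data):
        # speech_data: 1-D float array sampled at 16 kHz
        inputs = processor(
            speech_data, return_tensors="pt", padding="longest", sampling_rate=16000
        ).input_values.to(device)
        with torch.no_grad():
            ids = torch.argmax(model(inputs).logits, dim=-1)
        # strip the <s> tokens, as the diff's post-processing does
        return processor.batch_decode(ids)[0].replace("<s>", "")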

@@ -1,5 +1,5 @@
 import os
-import logging
+# import logging
 from pathlib import Path

 # from rpyc.utils.server import ThreadedServer
@@ -10,6 +10,10 @@ from plume.utils import lazy_callable

 # from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
 # from .asr import Wav2Vec2ASR
+# logging.basicConfig(
+#     level=logging.INFO,
+#     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+# )

 ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
 Wav2Vec2TransformersASR = lazy_callable(
@@ -41,13 +45,12 @@ app = typer.Typer()
 def rpyc_dir(
     model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
 ):
+    typer.echo("loading asr model...")
     w2vasr = Wav2Vec2TransformersASR(model_dir)
+    typer.echo("loaded asr model")
     service = ASRService(w2vasr)
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-    )
-    logging.info("starting asr server...")

+    typer.echo(f"serving asr on :{port}...")
     t = ThreadedServer(service, port=port)
     t.start()
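A minimal client sketch for the server rpyc_dir starts (the port default comes from the diff; the exposed method name is an assumption, since ASRService itself is not shown in this hunk):

    import rpyc

    conn = rpyc.connect("localhost", 8044)
    # transcript = conn.root.transcribe(wav_bytes)  # hypothetical exposed method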

@@ -46,7 +46,10 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
     st.title(f"ASR Validation - # {task_uid}")
     st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
     new_sample = st.number_input(
-        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
+        "Go To Sample:",
+        value=sample_no + 1,
+        min_value=1,
+        max_value=len(asr_data),
     )
     if new_sample != sample_no + 1:
         st.update_cursor(new_sample - 1)
@@ -60,7 +63,9 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
     show_key(sample, "asr_wer", trail="%")
     show_key(sample, "correct_candidate")

-    st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
+    if "plot_path" in sample:
+        st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())

     st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
     # set default to text
     corrected = sample["text"]
@@ -78,27 +83,37 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
             corrected = ""
     if st.button("Submit"):
         st.update_entry(
-            sample["utterance_id"], {"status": selected, "correction": corrected}
+            sample["utterance_id"],
+            {"status": selected, "correction": corrected},
         )
         st.update_cursor(sample_no + 1)
     if correction_entry:
         status = correction_entry["value"]["status"]
         correction = correction_entry["value"]["correction"]
-        st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
+        st.markdown(
+            f"Your Response: **{status}** Correction: **{correction}**"
+        )
     text_sample = st.text_input("Go to Text:", value="")
     if text_sample != "":
-        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
+        candidates = [
+            i for (i, p) in enumerate(asr_data) if p["text"] == text_sample
+        ]
        if len(candidates) > 0:
            st.update_cursor(candidates[0])
-    real_idx = st.number_input(
-        "Go to real-index",
-        value=sample["real_idx"],
-        min_value=0,
-        max_value=len(asr_data) - 1,
-    )
-    if real_idx != int(sample["real_idx"]):
-        idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
-        st.update_cursor(idx)
+    if "real_idx" in sample:
+        real_idx = st.number_input(
+            "Go to real-index",
+            value=sample["real_idx"],
+            min_value=0,
+            max_value=len(asr_data) - 1,
+        )
+        if real_idx != int(sample["real_idx"]):
+            idx = [
+                i
+                for (i, p) in enumerate(asr_data)
+                if p["real_idx"] == real_idx
+            ][0]
+            st.update_cursor(idx)


 if __name__ == "__main__":
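The show_key helper is not shown in this diff; a hypothetical implementation consistent with its call sites, matching the new treat-keys-as-optional convention (plot_path and real_idx above get the same guard):

    import streamlit as st

    def show_key(sample, key, trail=""):
        # render a metadata field only when the sample actually carries it
        if key in sample:
            st.markdown(f"**{key}**: {sample[key]}{trail}")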

@@ -27,7 +27,10 @@ def main(manifest: Path):
     st.title("ASR Manifest Preview")
     st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
     new_sample = st.number_input(
-        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
+        "Go To Sample:",
+        value=sample_no + 1,
+        min_value=1,
+        max_value=len(asr_data),
     )
     if new_sample != sample_no + 1:
         st.update_cursor(new_sample - 1)

@@ -1,7 +1,6 @@
 import io
 import os
 import re
 import json
 import wave
-import logging
 import subprocess

@@ -0,0 +1,65 @@
+from pathlib import Path
+from plume.utils import lazy_module
+from plume.utils.audio import audio_seg_to_wav_bytes
+
+pydub = lazy_module("pydub")
+torch = lazy_module("torch")
+
+
+def transform_audio(file_location, path_to_save):
+    audio_seg = (
+        pydub.AudioSegment.from_file(file_location)
+        .set_frame_rate(16000)
+        .set_sample_width(2)
+    )
+    audio_seg.export(path_to_save, format="wav")
+
+
+def gen_diarizer():
+    pipeline = torch.hub.load("pyannote/pyannote-audio", "dia")
+
+    def _diarizer(audio_path):
+        return pipeline({"audio": audio_path})
+
+    return _diarizer
+
+
+# base_transcriber, base_prep = transcribe_rpyc_gen()
+# transcriber, prep = chunk_transcribe_meta_gen(
+#     base_transcriber, base_prep, method="chunked")
+
+# diarizer = gen_diarizer()
+
+
+def diarize_audio_gen():
+    diarizer = gen_diarizer()
+
+    def _diarize_audio(audio_path: Path):
+        aseg = (
+            pydub.AudioSegment.from_file(audio_path)
+            .set_frame_rate(16000)
+            .set_sample_width(2)
+            .set_channels(1)
+        )
+        aseg.export("/tmp/temp.wav", format="wav")
+        diarization = diarizer("/tmp/temp.wav")
+        for n, (turn, _, speaker) in enumerate(
+            diarization.itertracks(yield_label=True)
+        ):
+            # speaker_label = "Agent" if speaker == "B" else "Customer"
+            turn_seg = aseg[turn.start * 1000 : turn.end * 1000]
+            sample_fname = (
+                audio_path.stem + "_" + str(n) + ".wav"
+            )
+            yield {
+                "speaker": speaker,
+                "wav": audio_seg_to_wav_bytes(turn_seg),
+                "wavseg": turn_seg,
+                "start": turn.start,
+                "end": turn.end,
+                "turnidx": n,
+                "filename": audio_path.name,
+                "sample_fname": sample_fname,
+            }
+
+    return _diarize_audio
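Usage sketch for the new module (import path per the dataset script's own import above; call.wav is a placeholder):

    from pathlib import Path
    from plume.utils.diarize import diarize_audio_gen

    diarize_audio = diarize_audio_gen()
    for turn in diarize_audio(Path("call.wav")):
        # each dict carries the turn's wav bytes and pydub segment
        # plus speaker/timing metadata
        print(turn["speaker"], turn["start"], turn["end"], turn["sample_fname"])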

@@ -47,12 +47,15 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
         asr_seg = (
             aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
         )
-        af = BytesIO()
-        asr_seg.export(af, format="wav")
-        input_audio_bytes = af.getvalue()
-        return input_audio_bytes
+        # af = BytesIO()
+        # asr_seg.export(af, format="wav")
+        # input_audio_bytes = af.getvalue()
+        return asr_seg

-    def dummy_transcript(aud, append_raw=False):
+    def dummy_transcript(asr_seg, append_raw=False):
+        af = BytesIO()
+        asr_seg.export(af, format="wav")
+        aud = af.getvalue()
         return asr.transcribe(aud)

     return dummy_transcript, audio_prep
@@ -147,6 +150,56 @@ def triton_transcribe_grpc_gen(
     return whole_transcriber, audio_prep


+def chunk_transcribe_meta_gen(
+    transcriber,
+    prep,
+    method="chunked",
+    chunk_msec=5000,
+    sil_msec=500,
+    sep=" ",
+):
+    from tritonclient.utils import np_to_triton_dtype, InferenceServerException
+    import tritonclient.grpc as grpcclient
+
+    # force loading
+    np.array
+
+    sup_meth = ["chunked", "silence", "whole"]
+    if method not in sup_meth:
+        meths = "|".join(sup_meth)
+        raise Exception(f"unsupported method {method}. pick one of {meths}")
+
+    def chunked_transcriber(aud_seg):
+        if method == "silence":
+            sil_chunks = pydub.silence.split_on_silence(
+                aud_seg,
+                min_silence_len=sil_msec,
+                silence_thresh=-50,
+                keep_silence=500,
+            )
+            chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
+        else:
+            chunks = aud_seg[::chunk_msec]
+        # if overlap:
+        #     chunks = [
+        #         aud_seg[start, end]
+        #         for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000))
+        #     ]
+        #     pass
+        transcript_list = []
+        sil_pad = pydub.AudioSegment.silent(duration=sil_msec)
+        for seg in chunks:
+            t_seg = sil_pad + seg + sil_pad
+            c_transcript = transcriber(t_seg)
+            transcript_list.append(c_transcript)
+        transcript = sep.join(transcript_list)
+        return transcript
+
+    whole_transcriber = (
+        transcriber if method == "whole" else chunked_transcriber
+    )
+
+    return whole_transcriber, prep
+
+
 @app.command()
 def audio_file(
     audio_file: Path,
@@ -157,13 +210,15 @@ def audio_file(
     model="slu_num_wav2vec2",
 ):
     aseg = pydub.AudioSegment.from_file(audio_file)
+    method = "chunked" if chunked else "whole"
     if rpyc:
-        transcriber, prep = transcribe_rpyc_gen()
+        base_transcriber, base_prep = transcribe_rpyc_gen()
     else:
-        method = "chunked" if chunked else "whole"
-        transcriber, prep = triton_transcribe_grpc_gen(
-            asr_model=model, method=method, append_raw=append_raw
+        base_transcriber, base_prep = triton_transcribe_grpc_gen(
+            asr_model=model, method="whole", append_raw=append_raw
         )
+    transcriber, prep = chunk_transcribe_meta_gen(
+        base_transcriber, base_prep, method=method)
     transcription = transcriber(prep(aseg))

     typer.echo(transcription)
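Composing the new abstraction, as the dataset script and the audio_file command now do: wrap any (transcriber, prep) pair so audio is split into silence-padded chunks whose transcripts are joined with sep. A minimal sketch, assuming a running rpyc ASR server:

    from plume.utils.transcribe import (
        transcribe_rpyc_gen,
        chunk_transcribe_meta_gen,
    )

    base_transcriber, base_prep = transcribe_rpyc_gen()
    transcriber, prep = chunk_transcribe_meta_gen(
        base_transcriber, base_prep, method="chunked", chunk_msec=5000
    )
    # transcript = transcriber(prep(aseg))  # aseg: a pydub.AudioSegment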

@@ -11,10 +11,17 @@ def setup_file_state(st):

     def current_cursor_fn():
         return task_path.read_json()["current_cursor"]
+        # if "audio_sample_idx" not in st.session_state:
+        #     st.session_state.audio_sample_idx = task_path.read_json()[
+        #         "current_cursor"
+        #     ]
+        # return st.session_state.audio_sample_idx

     def update_cursor_fn(val=0):
         task_path.write_json({"current_cursor": val})
-        rerun()
+        # rerun()
+        # st.session_state.audio_sample_idx = val
+        st.experimental_rerun()

     st.get_current_cursor = current_cursor_fn
     st.update_cursor = update_cursor_fn
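The pattern this change relies on: Streamlit 1.0 exposes st.experimental_rerun(), replacing the pre-1.0 rerun workaround. A minimal sketch (widget label and counter key are illustrative):

    import streamlit as st

    if st.button("next sample"):
        st.session_state["cursor"] = st.session_state.get("cursor", 0) + 1
        st.experimental_rerun()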