1. set flake8 max-line to 79

2. update streamlit dep to 1.0
3. add dev optional dep key
4. implement mono diarized dataset generation script
5. enable gpu support on asr transformers inference pipeline
6. use typer logging
7. clean up annotation ui: make everything other than asr-data keys optional (including plots)
8. implement chunk_transcribe_meta_gen abstraction for asr chunking logic
9. make ui_persist compatibility change for streamlit 1.0
10. add diarize commands (bugfix)
11. add notebooks for diarization
master
Malar Kannan 2021-10-24 01:32:15 +05:30
parent 846f029cf1
commit 79aa5e8578
13 changed files with 483 additions and 43 deletions

View File

@@ -1,4 +1,4 @@
[flake8]
exclude = docs
ignore = E203, W503
max-line-length = 119
max-line-length = 79

notebooks/Diarization.ipynb Normal file (225 lines)

File diff suppressed because one or more lines are too long

View File

@@ -76,7 +76,7 @@ extra_requirements = {
"pymongo~=3.10.1",
"matplotlib~=3.2.1",
"pydub~=0.24.0",
"streamlit~=0.58.0",
"streamlit~=1.0.0",
"natural~=0.2.0",
"stringcase~=1.2.0",
"google-cloud-speech~=1.3.1",
@@ -85,6 +85,11 @@ extra_requirements = {
"pyspellchecker~=0.6.2",
"google-cloud-texttospeech~=1.0.1",
"rangehttpserver~=1.2.0",
"streamlit~=1.0.0",
],
"dev": [
"jupyterlab~=3.1.18",
"ipykernel~=6.4.1",
],
"crypto": ["cryptography~=3.4.7"],
"train": ["torchaudio~=0.6.0", "torch-stft~=0.1.4"],

View File

@@ -1,11 +1,17 @@
from pathlib import Path
import shutil
from tqdm import tqdm
import typer
from plume.utils.lazy_import import lazy_module
from plume.utils.tts import GoogleTTS
from plume.utils.transcribe import triton_transcribe_grpc_gen
from plume.utils.transcribe import (
triton_transcribe_grpc_gen,
chunk_transcribe_meta_gen,
transcribe_rpyc_gen,
)
from plume.utils.manifest import asr_manifest_writer
from plume.utils.diarize import diarize_audio_gen
from plume.utils.extended_path import ExtendedPath
pydub = lazy_module("pydub")
app = typer.Typer()
@@ -50,4 +56,46 @@ def asr_dataset(audio_dir: Path, out_dir: Path, model="slu_num_wav2vec2"):
"text": transcript,
}
asr_manifest_writer(out_dir / 'manifest.json', data_gen())
asr_manifest_writer(out_dir / "manifest.json", data_gen())
@app.command()
def mono_diarize_asr_dataset(audio_dir: Path, out_dir: Path):
out_wav_dir = out_dir / "wavs"
out_wav_dir.mkdir(exist_ok=True, parents=True)
diarize_audio = diarize_audio_gen()
def data_gen():
aud_files = list(audio_dir.glob("*/*.mp3")) + list(
audio_dir.glob("*/*.wav")
)
diameta = ExtendedPath(out_dir / "diameta.json")
base_transcriber, base_prep = transcribe_rpyc_gen()
transcriber, prep = chunk_transcribe_meta_gen(
base_transcriber, base_prep, method="chunked"
)
diametadata = []
for af in tqdm(aud_files):
try:
# raise Exception("Test")
for dres in diarize_audio(af):
sample_fname = dres.pop("sample_fname")
out_af = out_wav_dir / sample_fname
wav_bytes = dres.pop("wav")
out_af.write_bytes(wav_bytes)
audio_af = out_af.relative_to(out_dir)
aud_seg = dres.pop("wavseg")
t_seg = prep(aud_seg)
transcript = transcriber(t_seg)
diametadata.append(dres)
yield {
"audio_filepath": str(audio_af),
"duration": aud_seg.duration_seconds,
"text": transcript,
}
except Exception as e:
print(f'error diarizing/transcribing {af} - {e}')
diameta.write_json(diametadata)
asr_manifest_writer(out_dir / "manifest.json", data_gen())
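
For reference, a minimal sketch (not part of this commit) of the streaming pattern mono_diarize_asr_dataset relies on: data_gen() yields one sample dict per diarized turn, and asr_manifest_writer is assumed to write each dict as one JSON object per line (NeMo-style manifest). The write_manifest helper and the sample values below are illustrative stand-ins.

import json
from pathlib import Path

def write_manifest(path: Path, samples):
    # stand-in for asr_manifest_writer: one JSON object per line
    with path.open("w") as f:
        for sample in samples:
            f.write(json.dumps(sample) + "\n")

def fake_data_gen():
    # hypothetical replacement for the diarize/transcribe loop above
    for n in range(3):
        yield {
            "audio_filepath": f"wavs/call01_{n}.wav",
            "duration": 1.5,
            "text": f"utterance {n}",
        }

write_manifest(Path("/tmp/manifest.json"), fake_data_gen())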

View File

View File

@@ -13,23 +13,38 @@ class Wav2Vec2TransformersASR(object):
"""docstring for Wav2Vec2TransformersASR."""
def __init__(self, model_dir):
super(Wav2Vec2TransformersASR, self).__init__()
# super(Wav2Vec2TransformersASR, self).__init__()
self.device = "cuda:1" if torch.cuda.is_available() else "cpu"
# sd = torch.load(
# model_dir / "pytorch_model.bin", map_location=self.device
# )
# self.processor = Wav2Vec2Processor.from_pretrained(
# model_dir, state_dict=sd
# )
# self.model = Wav2Vec2ForCTC.from_pretrained(model_dir, state_dict=sd).to(self.device)
self.processor = Wav2Vec2Processor.from_pretrained(model_dir)
self.model = Wav2Vec2ForCTC.from_pretrained(model_dir)
self.model = Wav2Vec2ForCTC.from_pretrained(model_dir).to(self.device)
def transcribe(self, audio_data):
aud_f = BytesIO(audio_data)
# net_input = {}
speech_data, _ = sf.read(aud_f)
input_values = self.processor(
speech_data, return_tensors="pt", padding="longest"
).input_values # Batch size 1
speech_data,
return_tensors="pt",
padding="longest",
sampling_rate=16000,
).input_values.to(
self.device
) # Batch size 1
# retrieve logits
#print(f"audio:{speech_data.shape} processed:{input_values.shape}")
logits = self.model(input_values).logits
#print(f"logit shape:{logits.shape}")
# take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
#print(f"predicted_ids shape:{predicted_ids.shape}")
transcription = self.processor.batch_decode(predicted_ids)[0]
return transcription
result = transcription.replace('<s>', '')
return result
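
Outside the diff, a minimal sketch of the GPU-placement pattern the change above introduces, using only public transformers/torch APIs. The checkpoint name and the wav path are illustrative (not the project's slu_num_wav2vec2 model), and the device index is arbitrary.

import torch
import soundfile as sf
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

device = "cuda:0" if torch.cuda.is_available() else "cpu"
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device)

speech, _ = sf.read("sample.wav")  # illustrative 16 kHz mono wav
input_values = processor(
    speech, return_tensors="pt", padding="longest", sampling_rate=16000
).input_values.to(device)
with torch.no_grad():
    logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0])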

View File

@@ -1,5 +1,5 @@
import os
import logging
# import logging
from pathlib import Path
# from rpyc.utils.server import ThreadedServer
@@ -10,6 +10,10 @@ from plume.utils import lazy_callable
# from plume.models.wav2vec2_transformers.asr import Wav2Vec2TransformersASR
# from .asr import Wav2Vec2ASR
# logging.basicConfig(
# level=logging.INFO,
# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
# )
ThreadedServer = lazy_callable("rpyc.utils.server.ThreadedServer")
Wav2Vec2TransformersASR = lazy_callable(
@@ -41,13 +45,12 @@ app = typer.Typer()
def rpyc_dir(
model_dir: Path, port: int = int(os.environ.get("ASR_RPYC_PORT", "8044"))
):
typer.echo("loading asr model...")
w2vasr = Wav2Vec2TransformersASR(model_dir)
typer.echo("loaded asr model")
service = ASRService(w2vasr)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logging.info("starting asr server...")
typer.echo(f"serving asr on :{port}...")
t = ThreadedServer(service, port=port)
t.start()

View File

@@ -46,7 +46,10 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
st.title(f"ASR Validation - # {task_uid}")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
"Go To Sample:",
value=sample_no + 1,
min_value=1,
max_value=len(asr_data),
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
@@ -60,7 +63,9 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
show_key(sample, "asr_wer", trail="%")
show_key(sample, "correct_candidate")
st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
if "plot_path" in sample:
st.sidebar.image((data_dir / Path(sample["plot_path"])).read_bytes())
st.audio((data_dir / Path(sample["audio_path"])).open("rb"))
# set default to text
corrected = sample["text"]
@@ -78,27 +83,37 @@ def main(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
corrected = ""
if st.button("Submit"):
st.update_entry(
sample["utterance_id"], {"status": selected, "correction": corrected}
sample["utterance_id"],
{"status": selected, "correction": corrected},
)
st.update_cursor(sample_no + 1)
if correction_entry:
status = correction_entry["value"]["status"]
correction = correction_entry["value"]["correction"]
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
st.markdown(
f"Your Response: **{status}** Correction: **{correction}**"
)
text_sample = st.text_input("Go to Text:", value="")
if text_sample != "":
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
candidates = [
i for (i, p) in enumerate(asr_data) if p["text"] == text_sample
]
if len(candidates) > 0:
st.update_cursor(candidates[0])
real_idx = st.number_input(
"Go to real-index",
value=sample["real_idx"],
min_value=0,
max_value=len(asr_data) - 1,
)
if real_idx != int(sample["real_idx"]):
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
st.update_cursor(idx)
if "real_idx" in sample:
real_idx = st.number_input(
"Go to real-index",
value=sample["real_idx"],
min_value=0,
max_value=len(asr_data) - 1,
)
if real_idx != int(sample["real_idx"]):
idx = [
i
for (i, p) in enumerate(asr_data)
if p["real_idx"] == real_idx
][0]
st.update_cursor(idx)
if __name__ == "__main__":
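
A small stand-alone sketch of the optional-key guards added above, so samples that lack plot_path or real_idx no longer break the annotation UI; the sample dict and data_dir below are illustrative stand-ins.

from pathlib import Path
import streamlit as st

data_dir = Path("/tmp/ui_dump")  # illustrative
sample = {"text": "hello world", "audio_path": "wavs/a.wav"}  # asr-data keys only

st.markdown(f"**{sample['text']}**")
if "plot_path" in sample:
    # plot is optional: render only when the key is present
    st.sidebar.image((data_dir / sample["plot_path"]).read_bytes())
if "real_idx" in sample:
    # real-index navigation is optional as well
    st.number_input("Go to real-index", value=sample["real_idx"], min_value=0)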

View File

@@ -27,7 +27,10 @@ def main(manifest: Path):
st.title("ASR Manifest Preview")
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
"Go To Sample:",
value=sample_no + 1,
min_value=1,
max_value=len(asr_data),
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)

View File

@@ -1,7 +1,6 @@
import io
import os
import re
import json
import wave
import logging
import subprocess

View File

@@ -0,0 +1,65 @@
from pathlib import Path
from plume.utils import lazy_module
from plume.utils.audio import audio_seg_to_wav_bytes
pydub = lazy_module('pydub')
torch = lazy_module('torch')
def transform_audio(file_location, path_to_save):
audio_seg = (
pydub.AudioSegment.from_file(file_location)
.set_frame_rate(16000)
.set_sample_width(2)
)
audio_seg.export(path_to_save, format="wav")
def gen_diarizer():
pipeline = torch.hub.load("pyannote/pyannote-audio", "dia")
def _diarizer(audio_path):
return pipeline({"audio": audio_path})
return _diarizer
# base_transcriber, base_prep = transcribe_rpyc_gen()
# transcriber, prep = chunk_transcribe_meta_gen(
# base_transcriber, base_prep, method="chunked")
# diarizer = gen_diarizer()
def diarize_audio_gen():
diarizer = gen_diarizer()
def _diarize_audio(audio_path: Path):
aseg = (
pydub.AudioSegment.from_file(audio_path)
.set_frame_rate(16000)
.set_sample_width(2)
.set_channels(1)
)
aseg.export("/tmp/temp.wav", format="wav")
diarization = diarizer("/tmp/temp.wav")
for n, (turn, _, speaker) in enumerate(
diarization.itertracks(yield_label=True)
):
# speaker_label = "Agent" if speaker == "B" else "Customer"
turn_seg = aseg[turn.start * 1000 : turn.end * 1000]
sample_fname = (
audio_path.stem + "_" + str(n) + ".wav"
)
yield {
"speaker": speaker,
"wav": audio_seg_to_wav_bytes(turn_seg),
"wavseg": turn_seg,
"start": turn.start,
"end": turn.end,
"turnidx": n,
"filename": audio_path.name,
"sample_fname": sample_fname
}
return _diarize_audio
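
A hedged usage sketch for diarize_audio_gen as defined above; the input wav and output directory are illustrative, and the pyannote/pyannote-audio "dia" pipeline must be downloadable via torch.hub for gen_diarizer() to work.

from pathlib import Path
from plume.utils.diarize import diarize_audio_gen

out_dir = Path("/tmp/diarized")
out_dir.mkdir(exist_ok=True, parents=True)

diarize_audio = diarize_audio_gen()  # loads the pyannote pipeline once
for turn in diarize_audio(Path("/tmp/call01.wav")):  # illustrative input file
    # each dict carries speaker, start/end (seconds), wav bytes and a filename
    (out_dir / turn["sample_fname"]).write_bytes(turn["wav"])
    print(turn["speaker"], round(turn["start"], 2), round(turn["end"], 2))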

View File

@@ -47,12 +47,15 @@ def transcribe_rpyc_gen(asr_host=ASR_RPYC_HOST, asr_port=ASR_RPYC_PORT):
asr_seg = (
aud_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
)
# af = BytesIO()
# asr_seg.export(af, format="wav")
# input_audio_bytes = af.getvalue()
return asr_seg
def dummy_transcript(asr_seg, append_raw=False):
af = BytesIO()
asr_seg.export(af, format="wav")
input_audio_bytes = af.getvalue()
return input_audio_bytes
def dummy_transcript(aud, append_raw=False):
aud = af.getvalue()
return asr.transcribe(aud)
return dummy_transcript, audio_prep
@@ -147,6 +150,56 @@
return whole_transcriber, audio_prep
def chunk_transcribe_meta_gen(
transcriber,
prep,
method="chunked",
chunk_msec=5000,
sil_msec=500,
sep=" ",
):
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
import tritonclient.grpc as grpcclient
# force loading
np.array
sup_meth = ["chunked", "silence", "whole"]
if method not in sup_meth:
meths = "|".join(sup_meth)
raise Exception(f"unsupported method {method}. pick one of {meths}")
def chunked_transcriber(aud_seg):
if method == "silence":
sil_chunks = pydub.silence.split_on_silence(
aud_seg,
min_silence_len=sil_msec,
silence_thresh=-50,
keep_silence=500,
)
chunks = [sc for c in sil_chunks for sc in c[::chunk_msec]]
else:
chunks = aud_seg[::chunk_msec]
# if overlap:
# chunks = [
# aud_seg[start, end]
# for start, end in range(0, int(aud_seg.duration_seconds * 1000, 1000))
# ]
# pass
transcript_list = []
sil_pad = pydub.AudioSegment.silent(duration=sil_msec)
for seg in chunks:
t_seg = sil_pad + seg + sil_pad
c_transcript = transcriber(t_seg)
transcript_list.append(c_transcript)
transcript = sep.join(transcript_list)
return transcript
whole_transcriber = (
transcriber if method == "whole" else chunked_transcriber
)
return whole_transcriber, prep
@app.command()
def audio_file(
audio_file: Path,
@@ -157,13 +210,15 @@
model="slu_num_wav2vec2",
):
aseg = pydub.AudioSegment.from_file(audio_file)
method = "chunked" if chunked else "whole"
if rpyc:
transcriber, prep = transcribe_rpyc_gen()
base_transcriber, base_prep = transcribe_rpyc_gen()
else:
method = "chunked" if chunked else "whole"
transcriber, prep = triton_transcribe_grpc_gen(
asr_model=model, method=method, append_raw=append_raw
base_transcriber, base_prep = triton_transcribe_grpc_gen(
asr_model=model, method='whole', append_raw=append_raw
)
transcriber, prep = chunk_transcribe_meta_gen(
base_transcriber, base_prep, method=method)
transcription = transcriber(prep(aseg))
typer.echo(transcription)
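
A sketch of how the new chunk_transcribe_meta_gen abstraction composes with a base transcriber, mirroring the audio_file command above; the wav path is illustrative and a running ASR rpyc server (as served by transcribe_rpyc_gen's counterpart) is assumed.

import pydub
from plume.utils.transcribe import transcribe_rpyc_gen, chunk_transcribe_meta_gen

base_transcriber, base_prep = transcribe_rpyc_gen()
transcriber, prep = chunk_transcribe_meta_gen(
    base_transcriber, base_prep, method="chunked", chunk_msec=5000, sil_msec=500
)

aseg = pydub.AudioSegment.from_file("/tmp/call01.wav")  # illustrative input
print(transcriber(prep(aseg)))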

View File

@@ -11,10 +11,17 @@ def setup_file_state(st):
def current_cursor_fn():
return task_path.read_json()["current_cursor"]
# if "audio_sample_idx" not in st.session_state:
# st.session_state.audio_sample_idx = task_path.read_json()[
# "current_cursor"
# ]
# return st.session_state.audio_sample_idx
def update_cursor_fn(val=0):
task_path.write_json({"current_cursor": val})
rerun()
# rerun()
# st.session_state.audio_sample_idx = val
st.experimental_rerun()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
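
Finally, a minimal sketch of the Streamlit 1.0 compatible cursor persistence that ui_persist now uses: the cursor lives in a JSON file and st.experimental_rerun() (available in streamlit ~1.0) replaces the old custom rerun helper. The task.json path below is illustrative; the current_cursor key matches the code above.

import json
from pathlib import Path
import streamlit as st

task_path = Path("/tmp/task.json")  # illustrative persistence file
if not task_path.exists():
    task_path.write_text(json.dumps({"current_cursor": 0}))

def get_current_cursor():
    return json.loads(task_path.read_text())["current_cursor"]

def update_cursor(val=0):
    task_path.write_text(json.dumps({"current_cursor": val}))
    st.experimental_rerun()  # public rerun API in streamlit >= 1.0

st.write(f"current cursor: {get_current_cursor()}")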