From 076b0d11e3dd83accea4516a4a0b5eeb75168604 Mon Sep 17 00:00:00 2001 From: Malar Date: Mon, 19 Jul 2021 15:20:50 +0530 Subject: [PATCH] 1. add some pyaudio dep 2. fixed merge/ add eject command with unlink option 3. wip - marblenet vad 4. add slu_infer ui util 5. fix filter command with maxmin support 6. some logging changes and fixes --- setup.py | 1 + src/plume/cli/data/__init__.py | 33 ++- src/plume/models/marblenet_nemo/trial.py | 347 ++++++++++++++++++++++- src/plume/models/wav2vec2/data.py | 6 +- src/plume/ui/__init__.py | 11 +- src/plume/ui/slu_infer.py | 32 +++ src/plume/utils/__init__.py | 39 +-- src/plume/utils/encrypt.py | 2 +- src/plume/utils/manifest.py | 11 + src/plume/utils/transcribe.py | 4 +- 10 files changed, 452 insertions(+), 34 deletions(-) create mode 100644 src/plume/ui/slu_infer.py diff --git a/setup.py b/setup.py index ec7b585..bc024e6 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,7 @@ extra_requirements = { "pyspellchecker~=0.6.2", "num2words~=0.5.10", "pydub~=0.24.0", + "pyaudio~=0.2.11" ], "infer_min": [ "pyspellchecker~=0.6.2", diff --git a/src/plume/cli/data/__init__.py b/src/plume/cli/data/__init__.py index e8de6a1..9ceb40d 100644 --- a/src/plume/cli/data/__init__.py +++ b/src/plume/cli/data/__init__.py @@ -73,7 +73,12 @@ def fix_path(dataset_path: Path, force: bool = False): @app.command() -def merge(src_dataset_paths: List[Path], dest_dataset_path: Path): +def merge( + src_dataset_paths: List[Path], + dest_dataset_path: Path, + unlink: bool = False, + verbose: bool = True, +): reader_list = [] abs_manifest_path = Path("abs_manifest.json") for dataset_path in src_dataset_paths: @@ -81,7 +86,29 @@ def merge(src_dataset_paths: List[Path], dest_dataset_path: Path): reader_list.append(asr_manifest_reader(manifest_path)) dest_dataset_path.mkdir(parents=True, exist_ok=True) dest_manifest_path = dest_dataset_path / abs_manifest_path - asr_manifest_writer(dest_manifest_path, chain(*reader_list)) + asr_manifest_writer( + dest_manifest_path, chain(*reader_list), verbose=verbose + ) + if unlink: + eject(dest_dataset_path, verbose=verbose) + + +def eject(dest_dataset_path: Path, verbose: bool = False): + wav_dir = dest_dataset_path / Path("wavs") + wav_dir.mkdir(exist_ok=True, parents=True) + abs_manifest_path = ExtendedPath( + dest_dataset_path / Path("abs_manifest.json") + ) + backup_abs_manifest_path = abs_manifest_path.with_suffix(".json.orig") + shutil.copy(abs_manifest_path, backup_abs_manifest_path) + manifest_data = list(abs_manifest_path.read_jsonl()) + for md in tqdm(manifest_data) if verbose else manifest_data: + orig_path = Path(md["audio_filepath"]) + new_path = wav_dir / Path(orig_path.name) + shutil.copy(orig_path, new_path) + md["audio_filepath"] = str(new_path) + abs_manifest_path.write_jsonl(manifest_data) + fix_path(dest_dataset_path) @app.command() @@ -275,7 +302,7 @@ def encrypt( src_dataset_path: Path, dest_dataset_path: Path, encryption_key: str = typer.Option(..., prompt=True, hide_input=True), - verbose: bool = False, + verbose: bool = True, ): dest_manifest = dest_dataset_path / Path("manifest.json") src_manifest = src_dataset_path / Path("manifest.json") diff --git a/src/plume/models/marblenet_nemo/trial.py b/src/plume/models/marblenet_nemo/trial.py index 2a21ddb..a0c49a5 100644 --- a/src/plume/models/marblenet_nemo/trial.py +++ b/src/plume/models/marblenet_nemo/trial.py @@ -1,22 +1,357 @@ import numpy as np import os -import time -import copy -from omegaconf import OmegaConf +# import time +import copy +import wave +import wget + +# 
from omegaconf import OmegaConf + import matplotlib.pyplot as plt +import librosa.display + import IPython.display as ipd + # import pyaudio as pa import librosa -import nemo + +# import nemo import nemo.collections.asr as nemo_asr +from nemo.core.classes import IterableDataset +from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType +import torch +from torch.utils.data import DataLoader # sample rate, Hz SAMPLE_RATE = 16000 +# import pdb; pdb.set_trace() -vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained( - "vad_marblenet" +# vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained( +# "vad_marblenet" +# ) +# vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained( +# model_name="MarbleNet-3x2x64-Telephony" +# ) +vad_model = nemo_asr.models.EncDecClassificationModel.restore_from( + "/home/malar/work/test/vad_telephony_marblenet.nemo" ) +# vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained( +# model_name="vad_telephony_marblenet" +# ) # Preserve a copy of the full config cfg = copy.deepcopy(vad_model._cfg) # print(OmegaConf.to_yaml(cfg)) + +vad_model.preprocessor = vad_model.from_config_dict(cfg.preprocessor) + +# Set model to inference mode +vad_model.eval() +vad_model = vad_model.to(vad_model.device) +# import pdb; pdb.set_trace() + +# simple data layer to pass audio signal +class AudioDataLayer(IterableDataset): + @property + def output_types(self): + return { + "audio_signal": NeuralType( + ("B", "T"), AudioSignal(freq=self._sample_rate) + ), + "a_sig_length": NeuralType(tuple("B"), LengthsType()), + } + + def __init__(self, sample_rate): + super().__init__() + self._sample_rate = sample_rate + self.output = True + + def __iter__(self): + return self + + def __next__(self): + if not self.output: + raise StopIteration + self.output = False + return ( + torch.as_tensor(self.signal, dtype=torch.float32), + torch.as_tensor(self.signal_shape, dtype=torch.int64), + ) + + def set_signal(self, signal): + self.signal = signal.astype(np.float32) / 32768.0 + self.signal_shape = self.signal.size + self.output = True + + def __len__(self): + return 1 + + +data_layer = AudioDataLayer(sample_rate=cfg.train_ds.sample_rate) +data_loader = DataLoader( + data_layer, batch_size=1, collate_fn=data_layer.collate_fn +) + + +# inference method for audio signal (single instance) +def infer_signal(model, signal): + data_layer.set_signal(signal) + batch = next(iter(data_loader)) + audio_signal, audio_signal_len = batch + audio_signal, audio_signal_len = ( + audio_signal.to(vad_model.device), + audio_signal_len.to(vad_model.device), + ) + logits = model.forward( + input_signal=audio_signal, input_signal_length=audio_signal_len + ) + return logits + + +# class for streaming frame-based VAD +# 1) use reset() method to reset FrameVAD's state +# 2) call transcribe(frame) to do VAD on +# contiguous signal's frames +class FrameVAD: + def __init__( + self, + model_definition, + threshold=0.5, + frame_len=2, + frame_overlap=2.5, + offset=10, + ): + """ + Args: + threshold: If prob of speech is larger than threshold, classify the segment to be speech. 
+ frame_len: frame's duration, seconds + frame_overlap: duration of overlaps before and after current frame, seconds + offset: number of symbols to drop for smooth streaming + """ + self.vocab = list(model_definition["labels"]) + self.vocab.append("_") + + self.sr = model_definition["sample_rate"] + self.threshold = threshold + self.frame_len = frame_len + self.n_frame_len = int(frame_len * self.sr) + self.frame_overlap = frame_overlap + self.n_frame_overlap = int(frame_overlap * self.sr) + timestep_duration = model_definition["AudioToMFCCPreprocessor"][ + "window_stride" + ] + for block in model_definition["JasperEncoder"]["jasper"]: + timestep_duration *= block["stride"][0] ** block["repeat"] + self.buffer = np.zeros( + shape=2 * self.n_frame_overlap + self.n_frame_len, dtype=np.float32 + ) + self.offset = offset + self.reset() + + def _decode(self, frame, offset=0): + assert len(frame) == self.n_frame_len + self.buffer[: -self.n_frame_len] = self.buffer[self.n_frame_len :] + self.buffer[-self.n_frame_len :] = frame + logits = infer_signal(vad_model, self.buffer).cpu().numpy()[0] + decoded = self._greedy_decoder(self.threshold, logits, self.vocab) + return decoded + + @torch.no_grad() + def transcribe(self, frame=None): + if frame is None: + frame = np.zeros(shape=self.n_frame_len, dtype=np.float32) + if len(frame) < self.n_frame_len: + frame = np.pad( + frame, [0, self.n_frame_len - len(frame)], "constant" + ) + unmerged = self._decode(frame, self.offset) + return unmerged + + def reset(self): + """ + Reset frame_history and decoder's state + """ + self.buffer = np.zeros(shape=self.buffer.shape, dtype=np.float32) + self.prev_char = "" + + @staticmethod + def _greedy_decoder(threshold, logits, vocab): + s = [] + if logits.shape[0]: + probs = torch.softmax(torch.as_tensor(logits), dim=-1) + probas, _ = torch.max(probs, dim=-1) + probas_s = probs[1].item() + preds = 1 if probas_s >= threshold else 0 + s = [ + preds, + str(vocab[preds]), + probs[0].item(), + probs[1].item(), + str(logits), + ] + return s + + +# WINDOW_SIZE_RANGE = [0.10, 0.15, 0.20, 0.25, 0.30, 0.5, 0.8] +# # STEP_RANGE = [0.01, 0.02, 0.03] +# # WINDOW_SIZE_RANGE = [0.15, 0.20] +# STEP_RANGE = [0.01, 0.02, 0.03] +WINDOW_SIZE_RANGE = [0.15, 0.20, 0.25] +# STEP_RANGE = [0.01, 0.02, 0.03] +# WINDOW_SIZE_RANGE = [0.15, 0.20] +STEP_RANGE = [0.03, 0.05, 0.07, 0.1] +STEP_LIST = [r for t in STEP_RANGE for r in [t]*len(WINDOW_SIZE_RANGE)] +# STEP_LIST +# STEP_LIST = ( +# [0.01] * len(WINDOW_SIZE_RANGE) +# + [0.02] * len(WINDOW_SIZE_RANGE) +# + [0.03] * len(WINDOW_SIZE_RANGE) +# ) +WINDOW_SIZE_LIST = WINDOW_SIZE_RANGE * len(STEP_RANGE) + + +def offline_inference(wave_file, STEP=0.025, WINDOW_SIZE=0.5, threshold=0.5): + + FRAME_LEN = STEP # infer every STEP seconds + CHANNELS = 1 # number of audio channels (expect mono signal) + RATE = 16000 # sample rate, Hz + + CHUNK_SIZE = int(FRAME_LEN * RATE) + + vad = FrameVAD( + model_definition={ + "sample_rate": SAMPLE_RATE, + "AudioToMFCCPreprocessor": cfg.preprocessor, + "JasperEncoder": cfg.encoder, + "labels": cfg.labels, + }, + threshold=threshold, + frame_len=FRAME_LEN, + frame_overlap=(WINDOW_SIZE - FRAME_LEN) / 2, + offset=0, + ) + + wf = wave.open(wave_file, "rb") + # p = pa.PyAudio() + + empty_counter = 0 + + preds = [] + proba_b = [] + proba_s = [] + + # stream = p.open( + # format=p.get_format_from_width(wf.getsampwidth()), + # channels=CHANNELS, + # rate=RATE, + # output=True, + # ) + + data = wf.readframes(CHUNK_SIZE) + + while len(data) > 0: + + data = wf.readframes(CHUNK_SIZE) 
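+        # note: this readframes() call overwrites the chunk fetched just before
+        # the loop, so the first CHUNK_SIZE frames are never passed to the VAD,
+        # and the final iteration runs on an empty buffer (zero-padded by transcribe)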
+ signal = np.frombuffer(data, dtype=np.int16) + result = vad.transcribe(signal) + + preds.append(result[0]) + proba_b.append(result[2]) + proba_s.append(result[3]) + + if len(result): + # print(result, end="\n") + empty_counter = 3 + elif empty_counter > 0: + empty_counter -= 1 + # if empty_counter == 0: + # print(" ", end="") + + # p.terminate() + vad.reset() + + return preds, proba_b, proba_s + + +# demo_wave = "VAD_demo.wav" +# if not os.path.exists(demo_wave): +# wget.download( +# "https://dldata-public.s3.us-east-2.amazonaws.com/VAD_demo.wav", +# demo_wave, +# ) + +demo_wave = "WAL-1201-cust.wav" + + +wave_file = demo_wave + +CHANNELS = 1 +RATE = 16000 +audio, sample_rate = librosa.load(wave_file, sr=RATE) +dur = librosa.get_duration(audio) +print(dur) + + +threshold = 0.5 + +results = [] +for STEP, WINDOW_SIZE in zip( + STEP_LIST, + WINDOW_SIZE_LIST, +): + print(f"====== STEP is {STEP}s, WINDOW_SIZE is {WINDOW_SIZE}s ====== ") + preds, proba_b, proba_s = offline_inference( + wave_file, STEP, WINDOW_SIZE, threshold + ) + results.append([STEP, WINDOW_SIZE, preds, proba_b, proba_s]) + + +plt.figure(figsize=[20, 3*len(STEP_LIST)]) + +num = len(results) +for i in range(num): + len_pred = len(results[i][2]) + FRAME_LEN = results[i][0] + ax1 = plt.subplot(num + 1, 1, i + 1) + + ax1.plot(np.arange(audio.size) / sample_rate, audio, "b") + ax1.set_xlim([-0.01, int(dur) + 1]) + ax1.tick_params(axis="y", labelcolor="b") + ax1.set_ylabel("Signal") + ax1.set_ylim([-1, 1]) + + proba_s = results[i][4] + pred = [1 if p > threshold else 0 for p in proba_s] + ax2 = ax1.twinx() + ax2.plot( + np.arange(len_pred) / (1 / results[i][0]), + np.array(pred), + "r", + label="pred", + ) + ax2.plot( + np.arange(len_pred) / (1 / results[i][0]), + np.array(proba_s), + "g--", + label="speech prob", + ) + ax2.tick_params(axis="y", labelcolor="r") + legend = ax2.legend(loc="lower right", shadow=True) + ax1.set_ylabel("prediction") + + ax2.set_title(f"step {results[i][0]}s, buffer size {results[i][1]}s") + ax2.set_ylabel("Preds and Probas") + + +ax = plt.subplot(num + 1, 1, i + 2) +S = librosa.feature.melspectrogram( + y=audio, sr=sample_rate, n_mels=64, fmax=8000 +) +S_dB = librosa.power_to_db(S, ref=np.max) +librosa.display.specshow( + S_dB, x_axis="time", y_axis="mel", sr=sample_rate, fmax=8000 +) +ax.set_title("Mel-frequency spectrogram") +ax.grid() +plt.show() +ipd.Audio(data=audio, rate=sample_rate) diff --git a/src/plume/models/wav2vec2/data.py b/src/plume/models/wav2vec2/data.py index f6a3fe9..c67da1f 100644 --- a/src/plume/models/wav2vec2/data.py +++ b/src/plume/models/wav2vec2/data.py @@ -55,10 +55,8 @@ def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool out_tsv = dest_dataset_path / Path(o_tsv) out_ltr = dest_dataset_path / Path(o_ltr) with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f: - if unlink: - tsv_f.write(f"{dest_dataset_path}\n") - else: - tsv_f.write(f"{src_dataset_path}\n") + dest_path = dest_dataset_path if unlink else src_dataset_path + tsv_f.write(f"{dest_path}\n") for md in manifest_data: audio_fname = md["audio_filepath"] pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper() diff --git a/src/plume/ui/__init__.py b/src/plume/ui/__init__.py index d8ab1e4..3af0f16 100644 --- a/src/plume/ui/__init__.py +++ b/src/plume/ui/__init__.py @@ -18,7 +18,9 @@ def ui(): @app.command() -def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""): +def annotation( + data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = "" 
+): annotation_lit_path = Path(__file__).parent / Path("annotation.py") if task_id: sys.argv = [ @@ -83,6 +85,13 @@ def audio(audio_dir: Path): sys.exit(stcli.main()) +@app.command() +def slu_infer(): + lit_path = Path(__file__).parent / Path("slu_infer.py") + sys.argv = ["streamlit", "run", str(lit_path)] + sys.exit(stcli.main()) + + @app.command() def collection(data_dir: Path, task_id: str = ""): # TODO: Implement web ui for data collection diff --git a/src/plume/ui/slu_infer.py b/src/plume/ui/slu_infer.py new file mode 100644 index 0000000..7b4f9ab --- /dev/null +++ b/src/plume/ui/slu_infer.py @@ -0,0 +1,32 @@ +# from pathlib import Path + +import streamlit as st +import typer + +from plume.utils.transcribe import triton_transcribe_grpc_gen +from plume.utils.audio import audio_wav_bytes_to_seg + +app = typer.Typer() + +transcriber, prep = triton_transcribe_grpc_gen( + asr_model="slu_num_wav2vec2", method="whole", append_raw=True +) + + +@app.command() +def main(): + st.title("SLU Inference") + audio_file = st.file_uploader("Upload File", type=["wav", "mp3"]) + if audio_file: + audio_bytes = audio_file.read() + seg = audio_wav_bytes_to_seg(audio_bytes) + st.audio(audio_bytes) + tscript = transcriber(prep(seg)) + st.write(tscript) + + +if __name__ == "__main__": + try: + app() + except SystemExit: + pass diff --git a/src/plume/utils/__init__.py b/src/plume/utils/__init__.py index 9bfdb64..87b2237 100644 --- a/src/plume/utils/__init__.py +++ b/src/plume/utils/__init__.py @@ -32,7 +32,11 @@ import six # from .transcribe import triton_transcribe_grpc_gen # from .eval import app as eval_app -from .manifest import asr_manifest_writer, manifest_str +from .manifest import ( + asr_manifest_writer, + asr_manifest_reader, + manifest_str, +) # noqa from .lazy_import import lazy_callable, lazy_module from .parallel import parallel_apply from .extended_path import ExtendedPath @@ -430,17 +434,6 @@ def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False): return num_datapoints -def asr_manifest_reader(data_manifest_path: Path): - print(f"reading manifest from {data_manifest_path}") - with data_manifest_path.open("r") as pf: - data_jsonl = pf.readlines() - data_data = [json.loads(v) for v in data_jsonl] - for p in data_data: - p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) - p["text"] = p["text"].strip() - yield p - - def asr_test_writer(out_file_path: Path, source): def dd_str(dd, idx): path = dd["audio_filepath"] @@ -558,11 +551,19 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): blank_count += 1 typer.echo(f"filtered {blank_count} of {total_count} blank samples") - def filtered_max_sample_dur(): + def filtered_maxmin_sample_dur(): + import soundfile + max_dur_count = 0 for s in src_data_enum: - wav_duration = s["duration"] - if wav_duration <= max_sample_dur: + wav_real_duration = soundfile.info( + src_dataset_path / Path(s["audio_filepath"]) + ).duration + wav_duration = min(wav_real_duration, s["duration"]) + if ( + wav_duration <= max_sample_dur + and wav_duration > min_sample_dur + ): shutil.copy( src_dataset_path / Path(s["audio_filepath"]), dest_dataset_path / Path(s["audio_filepath"]), @@ -571,7 +572,7 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): else: max_dur_count += 1 typer.echo( - f"filtered {max_dur_count} samples longer thans {max_sample_dur}s" + f"filtered {max_dur_count} samples longer thans {max_sample_dur}s and shorter than {min_sample_dur}s" ) def filtered_transform_digits(): @@ 
-641,7 +642,9 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): wav_duration = 0 for s in src_data_enum: # nums = re.sub(" ", "", s["text"]) - s["text"] = "gAAAAABgq2FR6ajbhMsDmWRQBzX6gIzyAG5sMwFihGeV7E_6eVJqqF78yzmtTJPsJAOJEEXhJ9Z45MrYNgE1sq7VUdsBVGh2cw==" + s[ + "text" + ] = "gAAAAABgq2FR6ajbhMsDmWRQBzX6gIzyAG5sMwFihGeV7E_6eVJqqF78yzmtTJPsJAOJEEXhJ9Z45MrYNgE1sq7VUdsBVGh2cw==" if ( s["duration"] >= min_sample_dur and s["duration"] <= max_sample_dur @@ -663,7 +666,7 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file): "transform_digits": filtered_transform_digits, "extract_chars": filtered_extract_chars, "resample_ulaw24kmono": filtered_resample, - "max_sample_dur": filtered_max_sample_dur, + "maxmin_sample_dur": filtered_maxmin_sample_dur, "msec_to_sec": filtered_msec_to_sec, "blank_3hr_max_dur": filtered_blank_hr_max_dur, } diff --git a/src/plume/utils/encrypt.py b/src/plume/utils/encrypt.py index cd84fc3..e3f84f2 100644 --- a/src/plume/utils/encrypt.py +++ b/src/plume/utils/encrypt.py @@ -13,7 +13,7 @@ from .audio import audio_seg_to_wav_bytes, audio_wav_bytes_to_seg from .parallel import parallel_apply from .lazy_import import lazy_module -cryptography = lazy_module("cryptography") +cryptography = lazy_module("cryptography.fernet", level='base') # cryptography.fernet = lazy_module("cryptography.fernet") pydub = lazy_module("pydub") diff --git a/src/plume/utils/manifest.py b/src/plume/utils/manifest.py index ee7a475..c1c98dc 100644 --- a/src/plume/utils/manifest.py +++ b/src/plume/utils/manifest.py @@ -13,6 +13,17 @@ def manifest_str(path, dur, text): return json.dumps(k) + "\n" +def asr_manifest_reader(data_manifest_path: Path): + print(f"reading manifest from {data_manifest_path}") + with data_manifest_path.open("r") as pf: + data_jsonl = pf.readlines() + data_data = [json.loads(v) for v in data_jsonl] + for p in data_data: + p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) + p["text"] = p["text"].strip() + yield p + + def asr_manifest_writer( asr_manifest_path: Path, manifest_str_source, verbose=False ): diff --git a/src/plume/utils/transcribe.py b/src/plume/utils/transcribe.py index bb02f24..a4661bc 100644 --- a/src/plume/utils/transcribe.py +++ b/src/plume/utils/transcribe.py @@ -104,6 +104,8 @@ def triton_transcribe_grpc_gen( if len(outputs) > 1 and append_raw: transcript = transcript + "|" + outputs[1].decode("utf-8") except InferenceServerException: + import traceback + traceback.print_exc() transcript = "[server error]" return transcript @@ -146,7 +148,7 @@ def triton_transcribe_grpc_gen( @app.command() -def file( +def audio_file( audio_file: Path, write_file: bool = False, chunked: bool = False,
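
Aside on the chunked read in offline_inference() above (not part of the patch): a minimal sketch of the usual read-process-read loop that feeds every chunk, including the first, to the VAD. It assumes a FrameVAD instance `vad` built exactly as in offline_inference(), and reuses the CHUNK_SIZE/RATE values from that function; everything else is the stdlib wave module and numpy.

import wave
import numpy as np

RATE = 16000
CHUNK_SIZE = int(0.025 * RATE)  # FRAME_LEN * RATE, as in offline_inference()
wf = wave.open("WAL-1201-cust.wav", "rb")
preds = []
data = wf.readframes(CHUNK_SIZE)
while len(data) > 0:
    signal = np.frombuffer(data, dtype=np.int16)
    preds.append(vad.transcribe(signal)[0])  # `vad`: FrameVAD instance, assumed from above
    data = wf.readframes(CHUNK_SIZE)  # fetch the next chunk only after processing this one
wf.close()
vad.reset()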