1. add pyaudio dependency
2. fix merge / add eject command with unlink option
3. WIP: MarbleNet VAD
4. add slu_infer UI util
5. fix filter command with max/min duration support
6. logging changes and fixes
tegra
Malar 2021-07-19 15:20:50 +05:30
parent 4bca2097e1
commit 076b0d11e3
10 changed files with 452 additions and 34 deletions

View File

@@ -66,6 +66,7 @@ extra_requirements = {
"pyspellchecker~=0.6.2",
"num2words~=0.5.10",
"pydub~=0.24.0",
"pyaudio~=0.2.11"
],
"infer_min": [
"pyspellchecker~=0.6.2",

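A tiny, hedged sanity check for the new dependency (assuming it is installed through the extra shown above); it only verifies that the PortAudio bindings load and reports how many devices are visible:

import pyaudio  # provided by the pyaudio~=0.2.11 pin above

pa = pyaudio.PyAudio()
print(pa.get_device_count(), "audio devices visible")
pa.terminate()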
View File

@@ -73,7 +73,12 @@ def fix_path(dataset_path: Path, force: bool = False):
@app.command()
def merge(src_dataset_paths: List[Path], dest_dataset_path: Path):
def merge(
src_dataset_paths: List[Path],
dest_dataset_path: Path,
unlink: bool = False,
verbose: bool = True,
):
reader_list = []
abs_manifest_path = Path("abs_manifest.json")
for dataset_path in src_dataset_paths:
@@ -81,7 +86,29 @@ def merge(src_dataset_paths: List[Path], dest_dataset_path: Path):
reader_list.append(asr_manifest_reader(manifest_path))
dest_dataset_path.mkdir(parents=True, exist_ok=True)
dest_manifest_path = dest_dataset_path / abs_manifest_path
asr_manifest_writer(dest_manifest_path, chain(*reader_list))
asr_manifest_writer(
dest_manifest_path, chain(*reader_list), verbose=verbose
)
if unlink:
eject(dest_dataset_path, verbose=verbose)
def eject(dest_dataset_path: Path, verbose: bool = False):
wav_dir = dest_dataset_path / Path("wavs")
wav_dir.mkdir(exist_ok=True, parents=True)
abs_manifest_path = ExtendedPath(
dest_dataset_path / Path("abs_manifest.json")
)
backup_abs_manifest_path = abs_manifest_path.with_suffix(".json.orig")
shutil.copy(abs_manifest_path, backup_abs_manifest_path)
manifest_data = list(abs_manifest_path.read_jsonl())
for md in tqdm(manifest_data) if verbose else manifest_data:
orig_path = Path(md["audio_filepath"])
new_path = wav_dir / Path(orig_path.name)
shutil.copy(orig_path, new_path)
md["audio_filepath"] = str(new_path)
abs_manifest_path.write_jsonl(manifest_data)
fix_path(dest_dataset_path)
@app.command()
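As a rough, standalone sketch of what the new unlink/eject step does (ExtendedPath, tqdm and the typer wiring are left out; the manifest field name follows the code above):

import json
import shutil
from pathlib import Path


def eject_sketch(dest_dataset_path: Path) -> None:
    # copy every referenced wav under <dataset>/wavs and point the manifest at the copies
    wav_dir = dest_dataset_path / "wavs"
    wav_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = dest_dataset_path / "abs_manifest.json"
    shutil.copy(manifest_path, manifest_path.with_suffix(".json.orig"))  # keep a backup
    entries = [
        json.loads(line)
        for line in manifest_path.read_text().splitlines()
        if line.strip()
    ]
    for md in entries:
        src = Path(md["audio_filepath"])
        dst = wav_dir / src.name
        shutil.copy(src, dst)
        md["audio_filepath"] = str(dst)
    manifest_path.write_text("".join(json.dumps(md) + "\n" for md in entries))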
@@ -275,7 +302,7 @@ def encrypt(
src_dataset_path: Path,
dest_dataset_path: Path,
encryption_key: str = typer.Option(..., prompt=True, hide_input=True),
verbose: bool = False,
verbose: bool = True,
):
dest_manifest = dest_dataset_path / Path("manifest.json")
src_manifest = src_dataset_path / Path("manifest.json")
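The key is prompted and hidden at the CLI. Judging from the cryptography.fernet import further down and the gAAAA... token in the filter hunk, the encrypted text fields are Fernet tokens; a minimal round-trip sketch (key handling here is purely illustrative):

from cryptography.fernet import Fernet

key = Fernet.generate_key()           # the CLI prompts for a key instead of generating one
f = Fernet(key)
token = f.encrypt(b"two three four")  # yields an opaque gAAAA... token like the one below
print(f.decrypt(token).decode())      # -> "two three four"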

View File

@@ -1,22 +1,357 @@
import numpy as np
import os
import time
import copy
from omegaconf import OmegaConf
# import time
import copy
import wave
import wget
# from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import librosa.display
import IPython.display as ipd
# import pyaudio as pa
import librosa
import nemo
# import nemo
import nemo.collections.asr as nemo_asr
from nemo.core.classes import IterableDataset
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType
import torch
from torch.utils.data import DataLoader
# sample rate, Hz
SAMPLE_RATE = 16000
# import pdb; pdb.set_trace()
vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(
"vad_marblenet"
# vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(
# "vad_marblenet"
# )
# vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(
# model_name="MarbleNet-3x2x64-Telephony"
# )
vad_model = nemo_asr.models.EncDecClassificationModel.restore_from(
"/home/malar/work/test/vad_telephony_marblenet.nemo"
)
# vad_model = nemo_asr.models.EncDecClassificationModel.from_pretrained(
# model_name="vad_telephony_marblenet"
# )
# Preserve a copy of the full config
cfg = copy.deepcopy(vad_model._cfg)
# print(OmegaConf.to_yaml(cfg))
vad_model.preprocessor = vad_model.from_config_dict(cfg.preprocessor)
# Set model to inference mode
vad_model.eval()
vad_model = vad_model.to(vad_model.device)
# import pdb; pdb.set_trace()
# simple data layer to pass audio signal
class AudioDataLayer(IterableDataset):
@property
def output_types(self):
return {
"audio_signal": NeuralType(
("B", "T"), AudioSignal(freq=self._sample_rate)
),
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
}
def __init__(self, sample_rate):
super().__init__()
self._sample_rate = sample_rate
self.output = True
def __iter__(self):
return self
def __next__(self):
if not self.output:
raise StopIteration
self.output = False
return (
torch.as_tensor(self.signal, dtype=torch.float32),
torch.as_tensor(self.signal_shape, dtype=torch.int64),
)
def set_signal(self, signal):
self.signal = signal.astype(np.float32) / 32768.0
self.signal_shape = self.signal.size
self.output = True
def __len__(self):
return 1
data_layer = AudioDataLayer(sample_rate=cfg.train_ds.sample_rate)
data_loader = DataLoader(
data_layer, batch_size=1, collate_fn=data_layer.collate_fn
)
# inference method for audio signal (single instance)
def infer_signal(model, signal):
data_layer.set_signal(signal)
batch = next(iter(data_loader))
audio_signal, audio_signal_len = batch
audio_signal, audio_signal_len = (
audio_signal.to(vad_model.device),
audio_signal_len.to(vad_model.device),
)
logits = model.forward(
input_signal=audio_signal, input_signal_length=audio_signal_len
)
return logits
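# logits has shape [batch=1, num_classes]; _greedy_decoder below treats index 1
# as the speech class and index 0 as background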
# class for streaming frame-based VAD
# 1) use reset() method to reset FrameVAD's state
# 2) call transcribe(frame) to do VAD on
# contiguous signal's frames
class FrameVAD:
def __init__(
self,
model_definition,
threshold=0.5,
frame_len=2,
frame_overlap=2.5,
offset=10,
):
"""
Args:
threshold: If prob of speech is larger than threshold, classify the segment to be speech.
frame_len: frame's duration, seconds
frame_overlap: duration of overlaps before and after current frame, seconds
offset: number of symbols to drop for smooth streaming
"""
self.vocab = list(model_definition["labels"])
self.vocab.append("_")
self.sr = model_definition["sample_rate"]
self.threshold = threshold
self.frame_len = frame_len
self.n_frame_len = int(frame_len * self.sr)
self.frame_overlap = frame_overlap
self.n_frame_overlap = int(frame_overlap * self.sr)
timestep_duration = model_definition["AudioToMFCCPreprocessor"][
"window_stride"
]
for block in model_definition["JasperEncoder"]["jasper"]:
timestep_duration *= block["stride"][0] ** block["repeat"]
self.buffer = np.zeros(
shape=2 * self.n_frame_overlap + self.n_frame_len, dtype=np.float32
)
self.offset = offset
self.reset()
def _decode(self, frame, offset=0):
assert len(frame) == self.n_frame_len
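# slide the rolling buffer left by one frame and write the new frame at the end,
# so the model always sees [overlap | current frame | overlap] worth of audio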
self.buffer[: -self.n_frame_len] = self.buffer[self.n_frame_len :]
self.buffer[-self.n_frame_len :] = frame
logits = infer_signal(vad_model, self.buffer).cpu().numpy()[0]
decoded = self._greedy_decoder(self.threshold, logits, self.vocab)
return decoded
@torch.no_grad()
def transcribe(self, frame=None):
if frame is None:
frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)
if len(frame) < self.n_frame_len:
frame = np.pad(
frame, [0, self.n_frame_len - len(frame)], "constant"
)
unmerged = self._decode(frame, self.offset)
return unmerged
def reset(self):
"""
Reset frame_history and decoder's state
"""
self.buffer = np.zeros(shape=self.buffer.shape, dtype=np.float32)
self.prev_char = ""
@staticmethod
def _greedy_decoder(threshold, logits, vocab):
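# returns [pred, label, p(background), p(speech), raw logits] for the current
# window; the frame is classified as speech when p(speech) >= threshold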
s = []
if logits.shape[0]:
probs = torch.softmax(torch.as_tensor(logits), dim=-1)
probas, _ = torch.max(probs, dim=-1)
probas_s = probs[1].item()
preds = 1 if probas_s >= threshold else 0
s = [
preds,
str(vocab[preds]),
probs[0].item(),
probs[1].item(),
str(logits),
]
return s
# WINDOW_SIZE_RANGE = [0.10, 0.15, 0.20, 0.25, 0.30, 0.5, 0.8]
# # STEP_RANGE = [0.01, 0.02, 0.03]
# # WINDOW_SIZE_RANGE = [0.15, 0.20]
# STEP_RANGE = [0.01, 0.02, 0.03]
WINDOW_SIZE_RANGE = [0.15, 0.20, 0.25]
# STEP_RANGE = [0.01, 0.02, 0.03]
# WINDOW_SIZE_RANGE = [0.15, 0.20]
STEP_RANGE = [0.03, 0.05, 0.07, 0.1]
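# repeat each step once per window size so that zip(STEP_LIST, WINDOW_SIZE_LIST)
# below enumerates the full step x window grid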
STEP_LIST = [t for t in STEP_RANGE for _ in WINDOW_SIZE_RANGE]
# STEP_LIST
# STEP_LIST = (
# [0.01] * len(WINDOW_SIZE_RANGE)
# + [0.02] * len(WINDOW_SIZE_RANGE)
# + [0.03] * len(WINDOW_SIZE_RANGE)
# )
WINDOW_SIZE_LIST = WINDOW_SIZE_RANGE * len(STEP_RANGE)
def offline_inference(wave_file, STEP=0.025, WINDOW_SIZE=0.5, threshold=0.5):
FRAME_LEN = STEP # infer every STEP seconds
CHANNELS = 1 # number of audio channels (expect mono signal)
RATE = 16000 # sample rate, Hz
CHUNK_SIZE = int(FRAME_LEN * RATE)
vad = FrameVAD(
model_definition={
"sample_rate": SAMPLE_RATE,
"AudioToMFCCPreprocessor": cfg.preprocessor,
"JasperEncoder": cfg.encoder,
"labels": cfg.labels,
},
threshold=threshold,
frame_len=FRAME_LEN,
frame_overlap=(WINDOW_SIZE - FRAME_LEN) / 2,
offset=0,
)
wf = wave.open(wave_file, "rb")
# p = pa.PyAudio()
empty_counter = 0
preds = []
proba_b = []
proba_s = []
# stream = p.open(
# format=p.get_format_from_width(wf.getsampwidth()),
# channels=CHANNELS,
# rate=RATE,
# output=True,
# )
data = wf.readframes(CHUNK_SIZE)
while len(data) > 0:
data = wf.readframes(CHUNK_SIZE)
signal = np.frombuffer(data, dtype=np.int16)
result = vad.transcribe(signal)
preds.append(result[0])
proba_b.append(result[2])
proba_s.append(result[3])
if len(result):
# print(result, end="\n")
empty_counter = 3
elif empty_counter > 0:
empty_counter -= 1
# if empty_counter == 0:
# print(" ", end="")
# p.terminate()
vad.reset()
return preds, proba_b, proba_s
# demo_wave = "VAD_demo.wav"
# if not os.path.exists(demo_wave):
# wget.download(
# "https://dldata-public.s3.us-east-2.amazonaws.com/VAD_demo.wav",
# demo_wave,
# )
demo_wave = "WAL-1201-cust.wav"
wave_file = demo_wave
CHANNELS = 1
RATE = 16000
audio, sample_rate = librosa.load(wave_file, sr=RATE)
dur = librosa.get_duration(y=audio, sr=sample_rate)  # pass sr; librosa's 22050 Hz default would misreport the duration of 16 kHz audio
print(dur)
threshold = 0.5
results = []
for STEP, WINDOW_SIZE in zip(
STEP_LIST,
WINDOW_SIZE_LIST,
):
print(f"====== STEP is {STEP}s, WINDOW_SIZE is {WINDOW_SIZE}s ====== ")
preds, proba_b, proba_s = offline_inference(
wave_file, STEP, WINDOW_SIZE, threshold
)
results.append([STEP, WINDOW_SIZE, preds, proba_b, proba_s])
plt.figure(figsize=[20, 3*len(STEP_LIST)])
num = len(results)
for i in range(num):
len_pred = len(results[i][2])
FRAME_LEN = results[i][0]
ax1 = plt.subplot(num + 1, 1, i + 1)
ax1.plot(np.arange(audio.size) / sample_rate, audio, "b")
ax1.set_xlim([-0.01, int(dur) + 1])
ax1.tick_params(axis="y", labelcolor="b")
ax1.set_ylabel("Signal")
ax1.set_ylim([-1, 1])
proba_s = results[i][4]
pred = [1 if p > threshold else 0 for p in proba_s]
ax2 = ax1.twinx()
ax2.plot(
np.arange(len_pred) / (1 / results[i][0]),
np.array(pred),
"r",
label="pred",
)
ax2.plot(
np.arange(len_pred) / (1 / results[i][0]),
np.array(proba_s),
"g--",
label="speech prob",
)
ax2.tick_params(axis="y", labelcolor="r")
legend = ax2.legend(loc="lower right", shadow=True)
ax1.set_ylabel("prediction")
ax2.set_title(f"step {results[i][0]}s, buffer size {results[i][1]}s")
ax2.set_ylabel("Preds and Probas")
ax = plt.subplot(num + 1, 1, i + 2)
S = librosa.feature.melspectrogram(
y=audio, sr=sample_rate, n_mels=64, fmax=8000
)
S_dB = librosa.power_to_db(S, ref=np.max)
librosa.display.specshow(
S_dB, x_axis="time", y_axis="mel", sr=sample_rate, fmax=8000
)
ax.set_title("Mel-frequency spectrogram")
ax.grid()
plt.show()
ipd.Audio(data=audio, rate=sample_rate)
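A small follow-up sketch, assuming the preds list returned by offline_inference above: frame i covers roughly [i*STEP, (i+1)*STEP) seconds, so contiguous speech frames can be merged into segments (this helper is illustrative, not part of the commit):

def preds_to_segments(preds, step):
    """Merge consecutive speech frames (pred == 1) into (start_s, end_s) segments."""
    segments, start = [], None
    for i, p in enumerate(preds):
        if p == 1 and start is None:
            start = i * step
        elif p == 0 and start is not None:
            segments.append((start, i * step))
            start = None
    if start is not None:
        segments.append((start, len(preds) * step))
    return segments


# e.g. segments for the first (STEP, WINDOW_SIZE) run:
# preds_to_segments(results[0][2], results[0][0])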

View File

@@ -55,10 +55,8 @@ def export_jasper(src_dataset_path: Path, dest_dataset_path: Path, unlink: bool
out_tsv = dest_dataset_path / Path(o_tsv)
out_ltr = dest_dataset_path / Path(o_ltr)
with out_tsv.open("w") as tsv_f, out_ltr.open("w") as ltr_f:
if unlink:
tsv_f.write(f"{dest_dataset_path}\n")
else:
tsv_f.write(f"{src_dataset_path}\n")
dest_path = dest_dataset_path if unlink else src_dataset_path
tsv_f.write(f"{dest_path}\n")
for md in manifest_data:
audio_fname = md["audio_filepath"]
pipe_toks = replace_redundant_spaces_with(md["text"], "|").upper()

View File

@@ -18,7 +18,9 @@ def ui():
@app.command()
def annotation(data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""):
def annotation(
data_dir: Path, dump_fname: Path = "ui_dump.json", task_id: str = ""
):
annotation_lit_path = Path(__file__).parent / Path("annotation.py")
if task_id:
sys.argv = [
@@ -83,6 +85,13 @@ def audio(audio_dir: Path):
sys.exit(stcli.main())
@app.command()
def slu_infer():
lit_path = Path(__file__).parent / Path("slu_infer.py")
sys.argv = ["streamlit", "run", str(lit_path)]
sys.exit(stcli.main())
@app.command()
def collection(data_dir: Path, task_id: str = ""):
# TODO: Implement web ui for data collection

src/plume/ui/slu_infer.py · Normal file · 32 additions
View File

@@ -0,0 +1,32 @@
# from pathlib import Path
import streamlit as st
import typer
from plume.utils.transcribe import triton_transcribe_grpc_gen
from plume.utils.audio import audio_wav_bytes_to_seg
app = typer.Typer()
transcriber, prep = triton_transcribe_grpc_gen(
asr_model="slu_num_wav2vec2", method="whole", append_raw=True
)
@app.command()
def main():
st.title("SLU Inference")
audio_file = st.file_uploader("Upload File", type=["wav", "mp3"])
if audio_file:
audio_bytes = audio_file.read()
seg = audio_wav_bytes_to_seg(audio_bytes)
st.audio(audio_bytes)
tscript = transcriber(prep(seg))
st.write(tscript)
if __name__ == "__main__":
try:
app()
except SystemExit:
pass
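# note: this page is meant to be launched through the new `slu_infer` ui command
# above (which effectively runs `streamlit run .../slu_infer.py`); invoking the
# module with plain python only exercises the typer wrapper, and Streamlit
# widgets render properly only under `streamlit run`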

View File

@@ -32,7 +32,11 @@ import six
# from .transcribe import triton_transcribe_grpc_gen
# from .eval import app as eval_app
from .manifest import asr_manifest_writer, manifest_str
from .manifest import (
asr_manifest_writer,
asr_manifest_reader,
manifest_str,
) # noqa
from .lazy_import import lazy_callable, lazy_module
from .parallel import parallel_apply
from .extended_path import ExtendedPath
@@ -430,17 +434,6 @@ def ui_dump_manifest_writer(dataset_dir, asr_data_source, verbose=False):
return num_datapoints
def asr_manifest_reader(data_manifest_path: Path):
print(f"reading manifest from {data_manifest_path}")
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
data_data = [json.loads(v) for v in data_jsonl]
for p in data_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["text"] = p["text"].strip()
yield p
def asr_test_writer(out_file_path: Path, source):
def dd_str(dd, idx):
path = dd["audio_filepath"]
@@ -558,11 +551,19 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
blank_count += 1
typer.echo(f"filtered {blank_count} of {total_count} blank samples")
def filtered_max_sample_dur():
def filtered_maxmin_sample_dur():
import soundfile
max_dur_count = 0
for s in src_data_enum:
wav_duration = s["duration"]
if wav_duration <= max_sample_dur:
wav_real_duration = soundfile.info(
src_dataset_path / Path(s["audio_filepath"])
).duration
wav_duration = min(wav_real_duration, s["duration"])
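# use the smaller of the manifest duration and the real duration reported by
# soundfile when applying the max/min bounds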
if (
wav_duration <= max_sample_dur
and wav_duration > min_sample_dur
):
shutil.copy(
src_dataset_path / Path(s["audio_filepath"]),
dest_dataset_path / Path(s["audio_filepath"]),
@@ -571,7 +572,7 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
else:
max_dur_count += 1
typer.echo(
f"filtered {max_dur_count} samples longer thans {max_sample_dur}s"
f"filtered {max_dur_count} samples longer thans {max_sample_dur}s and shorter than {min_sample_dur}s"
)
def filtered_transform_digits():
@@ -641,7 +642,9 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
wav_duration = 0
for s in src_data_enum:
# nums = re.sub(" ", "", s["text"])
s["text"] = "gAAAAABgq2FR6ajbhMsDmWRQBzX6gIzyAG5sMwFihGeV7E_6eVJqqF78yzmtTJPsJAOJEEXhJ9Z45MrYNgE1sq7VUdsBVGh2cw=="
s[
"text"
] = "gAAAAABgq2FR6ajbhMsDmWRQBzX6gIzyAG5sMwFihGeV7E_6eVJqqF78yzmtTJPsJAOJEEXhJ9Z45MrYNgE1sq7VUdsBVGh2cw=="
if (
s["duration"] >= min_sample_dur
and s["duration"] <= max_sample_dur
@@ -663,7 +666,7 @@ def generate_filter_map(src_dataset_path, dest_dataset_path, data_file):
"transform_digits": filtered_transform_digits,
"extract_chars": filtered_extract_chars,
"resample_ulaw24kmono": filtered_resample,
"max_sample_dur": filtered_max_sample_dur,
"maxmin_sample_dur": filtered_maxmin_sample_dur,
"msec_to_sec": filtered_msec_to_sec,
"blank_3hr_max_dur": filtered_blank_hr_max_dur,
}

View File

@@ -13,7 +13,7 @@ from .audio import audio_seg_to_wav_bytes, audio_wav_bytes_to_seg
from .parallel import parallel_apply
from .lazy_import import lazy_module
cryptography = lazy_module("cryptography")
cryptography = lazy_module("cryptography.fernet", level='base')
# cryptography.fernet = lazy_module("cryptography.fernet")
pydub = lazy_module("pydub")

View File

@@ -13,6 +13,17 @@ def manifest_str(path, dur, text):
return json.dumps(k) + "\n"
def asr_manifest_reader(data_manifest_path: Path):
print(f"reading manifest from {data_manifest_path}")
with data_manifest_path.open("r") as pf:
data_jsonl = pf.readlines()
data_data = [json.loads(v) for v in data_jsonl]
for p in data_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["text"] = p["text"].strip()
yield p
def asr_manifest_writer(
asr_manifest_path: Path, manifest_str_source, verbose=False
):
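For reference, a rough sketch of a single manifest line these helpers exchange (field names follow asr_manifest_reader above; the values are made up):

import json

line = json.dumps(
    {"audio_filepath": "wavs/utt_0001.wav", "duration": 2.31, "text": "two three four"}
)
print(line)
# asr_manifest_reader later resolves "audio_filepath" relative to the manifest's
# directory (as "audio_path") and whitespace-strips "text"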

View File

@@ -104,6 +104,8 @@ def triton_transcribe_grpc_gen(
if len(outputs) > 1 and append_raw:
transcript = transcript + "|" + outputs[1].decode("utf-8")
except InferenceServerException:
import traceback
traceback.print_exc()
transcript = "[server error]"
return transcript
@@ -146,7 +148,7 @@ def triton_transcribe_grpc_gen(
@app.command()
def file(
def audio_file(
audio_file: Path,
write_file: bool = False,
chunked: bool = False,