From 1acf9e403c6f2dbdc37317eba361b7fac6cf376b Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Wed, 27 May 2020 15:19:25 +0530
Subject: [PATCH] 1. added support for mono/dual-channel rev transcripts
 2. handle errors when extracting datapoints from rev metadata
 3. added support for annotation-only task when dumping UI data

---
 jasper/data/rev_recycler.py        | 37 +++++++++++++++++++++-
 jasper/data/validation/__init__.py |  1 +
 jasper/data/validation/process.py  | 52 +++++++++++++++++++-----------
 jasper/data/validation/ui.py       | 31 +++++++++---------
 4 files changed, 87 insertions(+), 34 deletions(-)
 create mode 100644 jasper/data/validation/__init__.py

diff --git a/jasper/data/rev_recycler.py b/jasper/data/rev_recycler.py
index bad1ab7..6871505 100644
--- a/jasper/data/rev_recycler.py
+++ b/jasper/data/rev_recycler.py
@@ -88,7 +88,7 @@ def extract_data(
         - datetime.datetime(1900, 1, 1)
     ).total_seconds() * 1000
 
-    def asr_data_generator(wav_seg, wav_path, meta):
+    def dual_asr_data_generator(wav_seg, wav_path, meta):
         left_audio, right_audio = wav_seg.split_to_mono()
         channel_map = {"Agent": right_audio, "Client": left_audio}
         monologues = lens["monologues"].Each().collect()(meta)
@@ -113,6 +113,7 @@ def extract_data(
                 )
             except IndexError:
                 print(f'error when loading timestamp events in wav:{wav_path} skipping.')
+                continue
 
             # offset by 500 msec to include first vad? discarded audio
             full_tscript_wav_seg = speaker_channel[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)]
@@ -124,10 +125,44 @@ def extract_data(
             text_clean = re.sub(r"\[.*\]", "", text)
             yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
 
+    def mono_asr_data_generator(wav_seg, wav_path, meta):
+        monologues = lens["monologues"].Each().collect()(meta)
+        for monologue in monologues:
+            try:
+                start_time = (
+                    lens["elements"]
+                    .Each()
+                    .Filter(lambda x: "timestamp" in x)["timestamp"]
+                    .collect()(monologue)[0]
+                )
+                end_time = (
+                    lens["elements"]
+                    .Each()
+                    .Filter(lambda x: "end_timestamp" in x)["end_timestamp"]
+                    .collect()(monologue)[-1]
+                )
+            except IndexError:
+                print(f'error when loading timestamp events in wav:{wav_path} skipping.')
+                continue
+
+            # offset by 500 msec to include first vad? discarded audio
+            full_tscript_wav_seg = wav_seg[time_to_msecs(start_time) - 500 : time_to_msecs(end_time)]
+            tscript_wav_seg = strip_silence(full_tscript_wav_seg)
+            tscript_wav_fb = BytesIO()
+            tscript_wav_seg.export(tscript_wav_fb, format="wav")
+            tscript_wav = tscript_wav_fb.getvalue()
+            text = "".join(lens["elements"].Each()["value"].collect()(monologue))
+            text_clean = re.sub(r"\[.*\]", "", text)
+            yield text_clean, tscript_wav_seg.duration_seconds, tscript_wav
+
     def generate_rev_asr_data():
         full_asr_data = []
         total_duration = 0
         for wav, wav_path, ev in wav_event_generator(call_audio_dir):
+            if wav.channels > 2:
+                print(f'skipping audio with more than two channels: {wav_path}')
+                continue
+            asr_data_generator = mono_asr_data_generator if wav.channels == 1 else dual_asr_data_generator
             asr_data = asr_data_generator(wav, wav_path, ev)
             total_duration += wav.duration_seconds
             full_asr_data.append(asr_data)
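
Taken together, the channel handling above reduces to the dispatch sketched below. This is a minimal sketch: the generator bodies are stand-ins for the extraction code in the patch, and pydub's AudioSegment supplies the channels attribute being tested. Recordings with more than two channels cannot be mapped onto the Agent/Client pair, so they are skipped rather than fed to the stereo path.

    # minimal sketch of the dispatch inside generate_rev_asr_data
    def mono_asr_data_generator(wav_seg, wav_path, meta):
        yield from ()  # stand-in for the single-channel extraction

    def dual_asr_data_generator(wav_seg, wav_path, meta):
        yield from ()  # stand-in for the split_to_mono() extraction

    def pick_asr_generator(wav, wav_path):
        if wav.channels > 2:
            print(f'skipping audio with more than two channels: {wav_path}')
            return None  # maps to the `continue` in the loop above
        return mono_asr_data_generator if wav.channels == 1 else dual_asr_data_generator
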
diff --git a/jasper/data/validation/__init__.py b/jasper/data/validation/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/jasper/data/validation/__init__.py
@@ -0,0 +1 @@
+
diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py
index 3ea1233..b9c7d2c 100644
--- a/jasper/data/validation/process.py
+++ b/jasper/data/validation/process.py
@@ -16,16 +16,12 @@ from ..utils import (
 app = typer.Typer()
 
 
-def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
+def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only):
     import matplotlib.pyplot as plt
     import librosa
     import librosa.display
     from pydub import AudioSegment
     from nemo.collections.asr.metrics import word_error_rate
-    from jasper.client import (
-        transcriber_pretrained,
-        transcriber_speller,
-    )
 
     try:
         res = dict(sample)
@@ -40,13 +36,18 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
             .set_sample_width(2)
             .set_frame_rate(24000)
         )
-        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
-        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
-        if use_domain_asr:
-            res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
-            res["domain_wer"] = word_error_rate(
-                [res["spoken"]], [res["pretrained_asr"]]
-            )
+        if not annotation_only:
+            from jasper.client import transcriber_pretrained, transcriber_speller
+
+            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
+            res["pretrained_wer"] = word_error_rate(
+                [res["text"]], [res["pretrained_asr"]]
+            )
+            if use_domain_asr:
+                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
+                res["domain_wer"] = word_error_rate(
+                    [res["spoken"]], [res["domain_asr"]]
+                )
         wav_plot_path = (
             rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
         )
@@ -67,9 +68,14 @@
 
 @app.command()
 def dump_validation_ui_data(
-    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
+    data_manifest_path: Path = typer.Option(
+        Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True
+    ),
+    dump_path: Path = typer.Option(
+        Path("./data/valiation_data/ui_dump.json"), show_default=True
+    ),
     use_domain_asr: bool = True,
+    annotation_only: bool = True,
 ):
     from concurrent.futures import ThreadPoolExecutor
     from functools import partial
@@ -86,6 +92,7 @@
             data_manifest_path.parent,
             json.loads(v),
             use_domain_asr,
+            annotation_only,
         )
         for i, v in enumerate(pnr_jsonl)
     ]
@@ -94,7 +101,7 @@ def dump_validation_ui_data(
         return f()
 
     with ThreadPoolExecutor() as exe:
-        print("starting all plot tasks")
+        print("starting all preprocess tasks")
         pnr_data = filter(
             None,
             list(
@@ -106,9 +113,16 @@
                 )
             ),
         )
-    wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
-    result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
-    ui_config = {"use_domain_asr": use_domain_asr, "data": result}
+    if annotation_only:
+        result = list(pnr_data)
+    else:
+        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
+        result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
+    ui_config = {
+        "use_domain_asr": use_domain_asr,
+        "data": result,
+        "annotation_only": annotation_only,
+    }
     ExtendedPath(dump_path).write_json(ui_config)
 
 
@@ -171,7 +185,9 @@ def update_corrections(
         elif d["chars"] in correction_map:
             correct_text = correction_map[d["chars"]]
             if skip_incorrect:
-                print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
+                print(
+                    f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
+                )
             else:
                 renamed_set.add(correct_text)
                 new_name = str(Path(correct_text).with_suffix(".wav"))
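
For orientation between the producer above and the Streamlit consumer below, the ui_dump.json payload has roughly the following shape. The key names are taken from the patch; the values and paths are purely illustrative. When annotation_only is set, the ASR/WER fields are never computed, and data keeps manifest order instead of the worst-WER-first sort.

    # illustrative payload only; key names come from the patch
    ui_config = {
        "use_domain_asr": True,      # enables the speller/domain fields below
        "annotation_only": False,    # when True, the ASR and WER keys are absent
        "data": [
            {
                "real_idx": 0,
                "text": "gold transcript",           # reference text
                "spoken": "spelled reference",       # used when use_domain_asr
                "pretrained_asr": "asr hypothesis",  # absent when annotation_only
                "pretrained_wer": 25.0,              # absent when annotation_only
                "domain_asr": "speller hypothesis",  # only when use_domain_asr
                "domain_wer": 12.5,                  # only when use_domain_asr
                "audio_path": "wavs/sample.wav",     # illustrative path
                "plot_path": "wav_plots/sample.png", # illustrative path
            }
        ],
    }
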
diff --git a/jasper/data/validation/ui.py b/jasper/data/validation/ui.py
index 04937f1..7ba67d3 100644
--- a/jasper/data/validation/ui.py
+++ b/jasper/data/validation/ui.py
@@ -61,14 +61,18 @@ def load_ui_data(validation_ui_data_path: Path):
 def main(manifest: Path):
     ui_config = load_ui_data(manifest)
     asr_data = ui_config["data"]
-    use_domain_asr = ui_config["use_domain_asr"]
+    use_domain_asr = ui_config.get("use_domain_asr", True)
+    annotation_only = ui_config.get("annotation_only", False)
     sample_no = st.get_current_cursor()
     if len(asr_data) - 1 < sample_no or sample_no < 0:
         print("Invalid samplno resetting to 0")
         st.update_cursor(0)
     sample = asr_data[sample_no]
     title_type = "Speller " if use_domain_asr else ""
-    st.title(f"ASR {title_type}Validation")
+    if annotation_only:
+        st.title("ASR Annotation")
+    else:
+        st.title(f"ASR {title_type}Validation")
     addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
     st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
     new_sample = st.number_input(
@@ -78,15 +82,16 @@
     st.update_cursor(new_sample - 1)
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
-    if use_domain_asr:
-        st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
-    st.sidebar.title("Results:")
-    st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
-    if use_domain_asr:
-        st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
-        st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
-    else:
-        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
+    if not annotation_only:
+        if use_domain_asr:
+            st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
+        st.sidebar.title("Results:")
+        st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
+        if use_domain_asr:
+            st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
+            st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
+        else:
+            st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
     st.sidebar.image(Path(sample["plot_path"]).read_bytes())
     st.audio(Path(sample["audio_path"]).open("rb"))
     # set default to text
@@ -113,10 +118,6 @@
     st.markdown(
         f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
     )
-    # if st.button("Previous Untagged"):
-    #     pass
-    # if st.button("Next Untagged"):
-    #     pass
     text_sample = st.text_input("Go to Text:", value='')
     if text_sample != '':
         candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
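
Since dump_validation_ui_data is a typer command, the annotation-only flow can be exercised end to end with typer's test runner. This is a hedged usage sketch that assumes jasper.data.validation.process imports cleanly in the calling environment; typer derives the kebab-case command name from the function, and the --annotation-only/--no-annotation-only flag pair from the boolean parameter.

    # hedged usage sketch; command and flag names are derived by typer from
    # dump_validation_ui_data and its annotation_only: bool parameter
    from typer.testing import CliRunner

    from jasper.data.validation.process import app

    runner = CliRunner()
    result = runner.invoke(app, ["dump-validation-ui-data", "--annotation-only"])
    assert result.exit_code == 0, result.output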