diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..170a050
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+exclude = docs
+ignore = E203, W503
+max-line-length = 119
diff --git a/jasper/data/utils.py b/jasper/data/utils.py
index 5409e1c..300a2da 100644
--- a/jasper/data/utils.py
+++ b/jasper/data/utils.py
@@ -58,83 +58,88 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     return num_datapoints


-def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
     dataset_dir = output_dir / Path(dataset_name)
     (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
-    ui_dump_file = dataset_dir / Path("ui_dump.json")
     (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
-    asr_manifest = dataset_dir / Path("manifest.json")
+
+    def data_fn(
+        transcript,
+        audio_dur,
+        wav_data,
+        caller_name,
+        aud_seg,
+        fname,
+        audio_path,
+        num_datapoints,
+        rel_data_path,
+    ):
+        pretrained_result = transcriber_pretrained(aud_seg.raw_data)
+        pretrained_wer = word_error_rate([transcript], [pretrained_result])
+        png_path = Path(fname).with_suffix(".png")
+        wav_plot_path = dataset_dir / Path("wav_plots") / png_path
+        if not wav_plot_path.exists():
+            plot_seg(wav_plot_path, audio_path)
+        return {
+            "audio_filepath": str(rel_data_path),
+            "duration": round(audio_dur, 1),
+            "text": transcript,
+            "real_idx": num_datapoints,
+            "audio_path": audio_path,
+            "spoken": transcript,
+            "caller": caller_name,
+            "utterance_id": fname,
+            "pretrained_asr": pretrained_result,
+            "pretrained_wer": pretrained_wer,
+            "plot_path": str(wav_plot_path),
+        }
+
     num_datapoints = 0
-    ui_dump = {
-        "use_domain_asr": False,
-        "annotation_only": False,
-        "enable_plots": True,
-        "data": [],
-    }
     data_funcs = []
     transcriber_pretrained = transcribe_gen(asr_port=8044)
+    for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
+        fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
+        audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
+        audio_file.write_bytes(wav_data)
+        audio_path = str(audio_file)
+        rel_data_path = audio_file.relative_to(dataset_dir)
+        data_funcs.append(
+            partial(
+                data_fn,
+                transcript,
+                audio_dur,
+                wav_data,
+                caller_name,
+                aud_seg,
+                fname,
+                audio_path,
+                num_datapoints,
+                rel_data_path,
+            )
+        )
+        num_datapoints += 1
+    ui_data = parallel_apply(lambda x: x(), data_funcs)
+    return ui_data, num_datapoints
+
+
+def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+    dataset_dir = output_dir / Path(dataset_name)
+    dump_data, num_datapoints = ui_data_generator(
+        output_dir, dataset_name, asr_data_source, verbose=verbose
+    )
+
+    asr_manifest = dataset_dir / Path("manifest.json")
     with asr_manifest.open("w") as mf:
         print(f"writing manifest to {asr_manifest}")
-
-        def data_fn(
-            transcript,
-            audio_dur,
-            wav_data,
-            caller_name,
-            aud_seg,
-            fname,
-            audio_path,
-            num_datapoints,
-            rel_data_path,
-        ):
-            pretrained_result = transcriber_pretrained(aud_seg.raw_data)
-            pretrained_wer = word_error_rate([transcript], [pretrained_result])
-            wav_plot_path = (
-                dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
-            )
-            if not wav_plot_path.exists():
-                plot_seg(wav_plot_path, audio_path)
-            return {
-                "audio_filepath": str(rel_data_path),
-                "duration": round(audio_dur, 1),
-                "text": transcript,
-                "real_idx": num_datapoints,
-                "audio_path": audio_path,
-                "spoken": transcript,
-                "caller": caller_name,
-                "utterance_id": fname,
-                "pretrained_asr": pretrained_result,
-                "pretrained_wer": pretrained_wer,
-                "plot_path": str(wav_plot_path),
-            }
-
-        for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
-            fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
-            audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
-            audio_file.write_bytes(wav_data)
-            audio_path = str(audio_file)
-            rel_data_path = audio_file.relative_to(dataset_dir)
+        for d in dump_data:
+            rel_data_path = d["audio_filepath"]
+            audio_dur = d["duration"]
+            transcript = d["text"]
             manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
             mf.write(manifest)
-            data_funcs.append(
-                partial(
-                    data_fn,
-                    transcript,
-                    audio_dur,
-                    wav_data,
-                    caller_name,
-                    aud_seg,
-                    fname,
-                    audio_path,
-                    num_datapoints,
-                    rel_data_path,
-                )
-            )
-            num_datapoints += 1
-    dump_data = parallel_apply(lambda x: x(), data_funcs)
-    # dump_data = [x() for x in tqdm(data_funcs)]
-    ui_dump["data"] = dump_data
-    ExtendedPath(ui_dump_file).write_json(ui_dump)
+
+    ui_dump_file = dataset_dir / Path("ui_dump.json")
+    ExtendedPath(ui_dump_file).write_json({"data": dump_data})
    return num_datapoints
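Note on the refactor above: `ui_data_generator` now owns wav writing, plot generation, transcription, and dump-record construction, while `ui_dump_manifest_writer` merely replays its output into `manifest.json`. The `(transcript, audio_dur, wav_data, caller_name, aud_seg)` tuple is the one contract this writer and `dump_ui` in process.py now share. A minimal sketch of driving the writer with a synthetic one-item source; the silent segment, dataset name, and paths are illustrative assumptions, and it presumes the pretrained ASR service behind `transcribe_gen(asr_port=8044)` is reachable:

```python
# Sketch, not part of the diff: drive ui_dump_manifest_writer with a synthetic
# one-item source. Assumes jasper is importable and an ASR service listens on
# port 8044 (per transcribe_gen above); names and paths are illustrative.
from io import BytesIO
from pathlib import Path

from pydub import AudioSegment

from jasper.data.utils import ui_dump_manifest_writer


def toy_source():
    # One second of silence, matching the 24 kHz mono 16-bit shape used elsewhere.
    seg = (
        AudioSegment.silent(duration=1000, frame_rate=24000)
        .set_channels(1)
        .set_sample_width(2)
    )
    buf = BytesIO()
    seg.export(buf, format="wav")
    # (transcript, audio_dur, wav_data, caller_name, aud_seg) is the tuple
    # contract ui_data_generator iterates over.
    yield "hello world", seg.duration_seconds, buf.getvalue(), "demo_caller", seg


n = ui_dump_manifest_writer(Path("./data/asr_data"), "demo_set", toy_source())
print(f"wrote {n} datapoints")
```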
diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py
index f887c27..619113b 100644
--- a/jasper/data/validation/process.py
+++ b/jasper/data/validation/process.py
@@ -1,13 +1,10 @@
 import json
 import shutil
 from pathlib import Path
-from enum import Enum

 import typer
-from tqdm import tqdm

 from ..utils import (
-    alnum_to_asr_tokens,
     ExtendedPath,
     asr_manifest_reader,
     asr_manifest_writer,
@@ -19,9 +16,7 @@ from ..utils import (
 app = typer.Typer()


-def preprocess_datapoint(
-    idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
-):
+def preprocess_datapoint(idx, rel_root, sample):
     from pydub import AudioSegment
     from nemo.collections.asr.metrics import word_error_rate
     from jasper.client import transcribe_gen
@@ -31,37 +26,23 @@ def preprocess_datapoint(
         res["real_idx"] = idx
         audio_path = rel_root / Path(sample["audio_filepath"])
         res["audio_path"] = str(audio_path)
-        if use_domain_asr:
-            res["spoken"] = alnum_to_asr_tokens(res["text"])
-        else:
-            res["spoken"] = res["text"]
         res["utterance_id"] = audio_path.stem
-        if not annotation_only:
-            transcriber_pretrained = transcribe_gen(asr_port=8044)
+        transcriber_pretrained = transcribe_gen(asr_port=8044)

-            aud_seg = (
-                AudioSegment.from_file_using_temporary_files(audio_path)
-                .set_channels(1)
-                .set_sample_width(2)
-                .set_frame_rate(24000)
-            )
-            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
-            res["pretrained_wer"] = word_error_rate(
-                [res["text"]], [res["pretrained_asr"]]
-            )
-            if use_domain_asr:
-                transcriber_speller = transcribe_gen(asr_port=8045)
-                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
-                res["domain_wer"] = word_error_rate(
-                    [res["spoken"]], [res["pretrained_asr"]]
-                )
-        if enable_plots:
-            wav_plot_path = (
-                rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
-            )
-            if not wav_plot_path.exists():
-                plot_seg(wav_plot_path, audio_path)
-            res["plot_path"] = str(wav_plot_path)
+        aud_seg = (
+            AudioSegment.from_file_using_temporary_files(audio_path)
+            .set_channels(1)
+            .set_sample_width(2)
+            .set_frame_rate(24000)
+        )
+        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
+        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
+        wav_plot_path = (
+            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
+        )
+        if not wav_plot_path.exists():
+            plot_seg(wav_plot_path, audio_path)
+        res["plot_path"] = str(wav_plot_path)
         return res
     except BaseException as e:
         print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
@@ -73,61 +54,50 @@ def dump_ui(
     dataset_dir: Path = Path("./data/asr_data"),
     dump_dir: Path = Path("./data/valiation_data"),
     dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
-    use_domain_asr: bool = False,
-    annotation_only: bool = False,
-    enable_plots: bool = True,
 ):
-    from concurrent.futures import ThreadPoolExecutor
-    from functools import partial
+    from io import BytesIO
+    from pydub import AudioSegment
+    from ..utils import ui_data_generator

     data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
-    dump_path: Path = dump_dir / Path(data_name) / dump_fname
     plot_dir = data_manifest_path.parent / Path("wav_plots")
     plot_dir.mkdir(parents=True, exist_ok=True)
     typer.echo(f"Using data manifest:{data_manifest_path}")
-    with data_manifest_path.open("r") as pf:
-        data_jsonl = pf.readlines()
-    data_funcs = [
-        partial(
-            preprocess_datapoint,
-            i,
-            data_manifest_path.parent,
-            json.loads(v),
-            use_domain_asr,
-            annotation_only,
-            enable_plots,
-        )
-        for i, v in enumerate(data_jsonl)
-    ]

-    def exec_func(f):
-        return f()
+    def asr_data_source_gen():
+        with data_manifest_path.open("r") as pf:
+            data_jsonl = pf.readlines()
+        for v in data_jsonl:
+            sample = json.loads(v)
+            rel_root = data_manifest_path.parent
+            res = dict(sample)
+            audio_path = rel_root / Path(sample["audio_filepath"])
+            audio_segment = (
+                AudioSegment.from_file_using_temporary_files(audio_path)
+                .set_channels(1)
+                .set_sample_width(2)
+                .set_frame_rate(24000)
+            )
+            wav_plot_path = (
+                rel_root
+                / Path("wav_plots")
+                / Path(audio_path.name).with_suffix(".png")
+            )
+            if not wav_plot_path.exists():
+                plot_seg(wav_plot_path, audio_path)
+            res["plot_path"] = str(wav_plot_path)
+            code_fb = BytesIO()
+            audio_segment.export(code_fb, format="wav")
+            wav_data = code_fb.getvalue()
+            duration = audio_segment.duration_seconds
+            asr_final = res["text"]
+            yield asr_final, duration, wav_data, "caller", audio_segment

-    with ThreadPoolExecutor() as exe:
-        print("starting all preprocess tasks")
-        data_final = filter(
-            None,
-            list(
-                tqdm(
-                    exe.map(exec_func, data_funcs),
-                    position=0,
-                    leave=True,
-                    total=len(data_funcs),
-                )
-            ),
-        )
-    if annotation_only:
-        result = list(data_final)
-    else:
-        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
-        result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
-    ui_config = {
-        "use_domain_asr": use_domain_asr,
-        "annotation_only": annotation_only,
-        "enable_plots": enable_plots,
-        "data": result,
-    }
-    ExtendedPath(dump_path).write_json(ui_config)
+    dump_data, num_datapoints = ui_data_generator(
+        dataset_dir, data_name, asr_data_source_gen()
+    )
+    dump_path = dump_dir / Path(data_name) / dump_fname
+    ExtendedPath(dump_path).write_json({"data": dump_data})


 @app.command()
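One behavioral difference worth noting: the removed `ThreadPoolExecutor` path sorted the dump by WER (worst first) before writing, whereas `ui_data_generator` returns records in manifest order. Callers that still want the old ordering must sort `dump_data` themselves, e.g. with a sketch like the one below, which relies only on the `pretrained_wer` field each record carries:

```python
# Sketch, not part of the diff: restore the old highest-WER-first ordering
# before writing the dump. ExtendedPath is the same helper this module
# imports from ..utils (jasper.data.utils).
from pathlib import Path

from jasper.data.utils import ExtendedPath


def write_sorted_dump(dump_data: list, dump_path: Path) -> None:
    # Worst transcriptions first, mirroring the sort the removed code applied.
    ordered = sorted(dump_data, key=lambda d: d["pretrained_wer"], reverse=True)
    ExtendedPath(dump_path).write_json({"data": ordered})
```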
@@ -190,7 +160,8 @@ def dump_corrections(
     col = get_mongo_conn(col="asr_validation")
     task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
-    corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
-    cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
+    cursor_obj = col.find(
+        {"type": "correction", "task_id": task_id}, projection={"_id": False}
+    )
     corrections = [c for c in cursor_obj]
     ExtendedPath(dump_path).write_json(corrections)
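For reference, the correction documents this command exports look roughly like the record below. The field names are taken from the query above and from `update_corrections` further down; the concrete values are invented:

```python
# Illustrative asr_validation correction document; shape inferred from the
# queries in dump_corrections and the lookups in update_corrections, values
# are made up.
correction = {
    "type": "correction",
    "task_id": "asr-task-abc123",  # suffix after the last "-" is the task_uid
    "code": "utt_0001",            # utterance key matched against correct_set
    "value": {"status": "Correct", "correction": ""},
}
```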
@@ -264,7 +236,9 @@ def split_extract(
     dump_file: Path = Path("ui_dump.json"),
     manifest_file: Path = Path("manifest.json"),
     corrections_file: str = typer.Option("corrections.json", show_default=True),
-    conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
+    conv_data_path: Path = typer.Option(
+        Path("./data/conv_data.json"), show_default=True
+    ),
     extraction_type: str = "all",
 ):
     import shutil
@@ -286,7 +260,9 @@ def split_extract(
     def extract_manifest(mg):
         for m in mg:
             if m["text"] in extraction_vals:
-                shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
+                shutil.copy(
+                    m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
+                )
                 yield m

     asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -295,12 +271,14 @@ def split_extract(
         orig_ui_data = ExtendedPath(ui_data_path).read_json()
         ui_data = orig_ui_data["data"]
         file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
-        extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
+        extracted_ui_data = list(
+            filter(lambda u: u["text"] in extraction_vals, ui_data)
+        )
         final_data = []
         for i, d in enumerate(extracted_ui_data):
-            d['real_idx'] = i
+            d["real_idx"] = i
             final_data.append(d)
-        orig_ui_data['data'] = final_data
+        orig_ui_data["data"] = final_data
         ExtendedPath(dest_ui_path).write_json(orig_ui_data)

     if corrections_file:
@@ -316,7 +294,7 @@ def split_extract(
         )
         ExtendedPath(dest_correction_path).write_json(extracted_corrections)

-    if extraction_type.value == 'all':
+    if extraction_type.value == "all":
         for ext_key in conv_data.keys():
             extract_data_of_type(ext_key)
     else:
@@ -338,7 +316,7 @@ def update_corrections(

     def correct_manifest(ui_dump_path, corrections_path):
         corrections = ExtendedPath(corrections_path).read_json()
-        ui_data = ExtendedPath(ui_dump_path).read_json()['data']
+        ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
         correct_set = {
             c["code"] for c in corrections if c["value"]["status"] == "Correct"
         }
@@ -367,7 +345,9 @@ def update_corrections(
             )
         else:
             orig_audio_path = Path(d["audio_path"])
-            new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
+            new_name = str(
+                Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
+            )
             new_audio_path = orig_audio_path.with_name(new_name)
             orig_audio_path.replace(new_audio_path)
             new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
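The `update_corrections` branch above renames audio files whose transcripts changed. Pulled out as a standalone sketch; `tscript_uuid_fname` exists in this codebase but its import path is not shown in the diff, so it is passed in as a parameter here, and the sample dict fields mirror the ui_dump records:

```python
# Sketch of the corrected-audio rename performed in update_corrections above;
# not part of the diff. tscript_uuid_fname is injected because the diff does
# not show where it lives.
from pathlib import Path


def rename_corrected_audio(d: dict, correct_text: str, tscript_uuid_fname) -> dict:
    orig_audio_path = Path(d["audio_path"])
    new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
    new_audio_path = orig_audio_path.with_name(new_name)
    orig_audio_path.replace(new_audio_path)  # rename on disk, overwriting if present
    return {
        **d,
        "text": correct_text,
        "audio_path": str(new_audio_path),
        "audio_filepath": str(Path(d["audio_filepath"]).with_name(new_name)),
    }
```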
diff --git a/jasper/data/validation/ui.py b/jasper/data/validation/ui.py
index 00f2e5c..8d6a72c 100644
--- a/jasper/data/validation/ui.py
+++ b/jasper/data/validation/ui.py
@@ -72,22 +72,18 @@ def main(manifest: Path, task_id: str = ""):
     st.set_task(manifest, task_id)
     ui_config = load_ui_data(manifest)
     asr_data = ui_config["data"]
-    use_domain_asr = ui_config.get("use_domain_asr", True)
     annotation_only = ui_config.get("annotation_only", False)
-    enable_plots = ui_config.get("enable_plots", True)
     sample_no = st.get_current_cursor()
     if len(asr_data) - 1 < sample_no or sample_no < 0:
         print("Invalid sample no, resetting to 0")
         st.update_cursor(0)
     sample = asr_data[sample_no]
-    title_type = "Speller " if use_domain_asr else ""
     task_uid = st.task_id.rsplit("-", 1)[1]
     if annotation_only:
         st.title(f"ASR Annotation - # {task_uid}")
     else:
-        st.title(f"ASR {title_type}Validation - # {task_uid}")
-        addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
-        st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
+        st.title(f"ASR Validation - # {task_uid}")
+        st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
     new_sample = st.number_input(
         "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
     )
@@ -96,19 +92,12 @@ def main(manifest: Path, task_id: str = ""):
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
     if not annotation_only:
         st.sidebar.title("Results:")
         st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
         if "caller" in sample:
             st.sidebar.markdown(f"Caller: **{sample['caller']}**")
-        if use_domain_asr:
-            st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
-            st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
-        else:
-            st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
-    if enable_plots:
-        st.sidebar.image(Path(sample["plot_path"]).read_bytes())
+        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
+    st.sidebar.image(Path(sample["plot_path"]).read_bytes())
     st.audio(Path(sample["audio_path"]).open("rb"))
     # set default to text
     corrected = sample["text"]
@@ -130,16 +120,12 @@ def main(manifest: Path, task_id: str = ""):
     )
     st.update_cursor(sample_no + 1)
     if correction_entry:
-        st.markdown(
-            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
-        )
+        status = correction_entry["value"]["status"]
+        correction = correction_entry["value"]["correction"]
+        st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
     text_sample = st.text_input("Go to Text:", value="")
     if text_sample != "":
-        candidates = [
-            i
-            for (i, p) in enumerate(asr_data)
-            if p["text"] == text_sample or p["spoken"] == text_sample
-        ]
+        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
        if len(candidates) > 0:
            st.update_cursor(candidates[0])
    real_idx = st.number_input(
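With `use_domain_asr` and `enable_plots` removed, the UI now assumes every `ui_dump.json` record carries the full field set `ui_data_generator` emits, notably `pretrained_asr`, `pretrained_wer`, and `plot_path`. A representative record, with the keys taken verbatim from `ui_data_generator` and the values invented:

```python
# Representative ui_dump.json entry consumed by main() above; keys mirror the
# dict built in ui_data_generator, values are illustrative only.
sample = {
    "audio_filepath": "wav/1f3c2a_hello-wo.wav",  # relative to the dataset dir
    "duration": 1.0,
    "text": "hello world",
    "real_idx": 0,
    "audio_path": "data/asr_data/demo_set/wav/1f3c2a_hello-wo.wav",
    "spoken": "hello world",  # kept equal to "text" after this diff
    "caller": "demo_caller",
    "utterance_id": "1f3c2a_hello-wo",
    "pretrained_asr": "hello world",
    "pretrained_wer": 0.0,
    "plot_path": "data/asr_data/demo_set/wav_plots/1f3c2a_hello-wo.png",
}
```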