diff --git a/jasper/data_utils/process.py b/jasper/data_utils/process.py index 7e38523..5d184f7 100644 --- a/jasper/data_utils/process.py +++ b/jasper/data_utils/process.py @@ -1,7 +1,7 @@ import json from pathlib import Path from sklearn.model_selection import train_test_split -from .utils import alnum_to_asr_tokens +from .utils import alnum_to_asr_tokens, asr_manifest_reader, asr_manifest_writer import typer app = typer.Typer() @@ -30,35 +30,38 @@ def separate_space_convert_digit_setpath(): @app.command() -def split_data(manifest_path: Path = Path("/dataset/asr_data/pnr_data/pnr_data.json")): - with manifest_path.open("r") as pf: - pnr_jsonl = pf.readlines() - train_pnr, test_pnr = train_test_split(pnr_jsonl, test_size=0.1) - with (manifest_path.parent / Path("train_manifest.json")).open("w") as pf: - pnr_data = "".join(train_pnr) - pf.write(pnr_data) - with (manifest_path.parent / Path("test_manifest.json")).open("w") as pf: - pnr_data = "".join(test_pnr) - pf.write(pnr_data) +def split_data(dataset_path: Path, test_size: float = 0.1): + manifest_path = dataset_path / Path("abs_manifest.json") + asr_data = list(asr_manifest_reader(manifest_path)) + train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size) + asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr) + asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr) @app.command() -def fix_path( - dataset_path: Path = Path("/dataset/asr_data/call_alphanum"), -): - manifest_path = dataset_path / Path('manifest.json') - with manifest_path.open("r") as pf: - pnr_jsonl = pf.readlines() - pnr_data = [json.loads(i) for i in pnr_jsonl] - new_pnr_data = [] - for i in pnr_data: +def fixate_data(dataset_path: Path): + manifest_path = dataset_path / Path("manifest.json") + real_manifest_path = dataset_path / Path("abs_manifest.json") + + def fix_path(): + for i in asr_manifest_reader(manifest_path): i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"])) - new_pnr_data.append(i) - new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data] - real_manifest_path = dataset_path / Path('real_manifest.json') - with real_manifest_path.open("w") as pf: - new_pnr_data = "\n".join(new_pnr_jsonl) # + "\n" - pf.write(new_pnr_data) + yield i + + asr_manifest_writer(real_manifest_path, fix_path()) + + # with manifest_path.open("r") as pf: + # pnr_jsonl = pf.readlines() + # pnr_data = [json.loads(i) for i in pnr_jsonl] + # new_pnr_data = [] + # for i in pnr_data: + # i["audio_filepath"] = str(dataset_path / Path(i["audio_filepath"])) + # new_pnr_data.append(i) + # new_pnr_jsonl = [json.dumps(i) for i in new_pnr_data] + # real_manifest_path = dataset_path / Path("abs_manifest.json") + # with real_manifest_path.open("w") as pf: + # new_pnr_data = "\n".join(new_pnr_jsonl) # + "\n" + # pf.write(new_pnr_data) @app.command() @@ -78,14 +81,18 @@ def augment_an4(): @app.command() -def validate_data(data_file: Path = Path("/dataset/asr_data/call_alphanum/train_manifest.json")): +def validate_data(data_file: Path): with Path(data_file).open("r") as pf: pnr_jsonl = pf.readlines() for (i, s) in enumerate(pnr_jsonl): try: - json.loads(s) + d = json.loads(s) + audio_file = data_file.parent / Path(d["audio_filepath"]) + if not audio_file.exists(): + raise OSError(f"File {audio_file} not found") except BaseException as e: - print(f"failed on {i}") + print(f'failed on {i} with "{e}"') + print("no errors found. seems like a valid manifest.") def main(): diff --git a/jasper/data_utils/utils.py b/jasper/data_utils/utils.py index aca31e8..36893d0 100644 --- a/jasper/data_utils/utils.py +++ b/jasper/data_utils/utils.py @@ -61,6 +61,27 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source): mf.write(manifest) +def asr_manifest_reader(data_manifest_path: Path): + print(f'reading manifest from {data_manifest_path}') + with data_manifest_path.open("r") as pf: + pnr_jsonl = pf.readlines() + pnr_data = [json.loads(v) for v in pnr_jsonl] + for p in pnr_data: + p['audio_path'] = data_manifest_path.parent / Path(p['audio_filepath']) + p['chars'] = Path(p['audio_filepath']).stem + yield p + + +def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source): + with asr_manifest_path.open("w") as mf: + print(f'opening {asr_manifest_path} for writing manifest') + for mani_dict in manifest_str_source: + manifest = manifest_str( + mani_dict['audio_filepath'], mani_dict['duration'], mani_dict['text'] + ) + mf.write(manifest) + + def main(): for c in random_pnr_generator(): print(c) diff --git a/jasper/data_utils/validation/process.py b/jasper/data_utils/validation/process.py new file mode 100644 index 0000000..2718771 --- /dev/null +++ b/jasper/data_utils/validation/process.py @@ -0,0 +1,176 @@ +import pymongo +import typer + +# import matplotlib.pyplot as plt +from pathlib import Path +import json +import shutil + +# import pandas as pd +from pydub import AudioSegment + +# from .jasper_client import transcriber_pretrained, transcriber_speller +from jasper.data_utils.validation.jasper_client import ( + transcriber_pretrained, + transcriber_speller, +) +from jasper.data_utils.utils import alnum_to_asr_tokens + +# import importlib +# import jasper.data_utils.utils +# importlib.reload(jasper.data_utils.utils) +from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer +from nemo.collections.asr.metrics import word_error_rate + +# from tqdm import tqdm as tqdm_base +from tqdm import tqdm + +app = typer.Typer() + + +@app.command() +def dump_corrections(dump_path: Path = Path("./data/corrections.json")): + col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation + + cursor_obj = col.find({"type": "correction"}, projection={"_id": False}) + corrections = [c for c in cursor_obj] + dump_f = dump_path.open("w") + json.dump(corrections, dump_f, indent=2) + dump_f.close() + + +def preprocess_datapoint(idx, rel, sample): + res = dict(sample) + res["real_idx"] = idx + audio_path = rel / Path(sample["audio_filepath"]) + res["audio_path"] = str(audio_path) + res["gold_chars"] = audio_path.stem + res["gold_phone"] = sample["text"] + aud_seg = ( + AudioSegment.from_wav(audio_path) + .set_channels(1) + .set_sample_width(2) + .set_frame_rate(24000) + ) + res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) + res["speller_asr"] = transcriber_speller(aud_seg.raw_data) + res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]]) + return res + + +def load_dataset(data_manifest_path: Path): + typer.echo(f"Using data manifest:{data_manifest_path}") + with data_manifest_path.open("r") as pf: + pnr_jsonl = pf.readlines() + pnr_data = [ + preprocess_datapoint(i, data_manifest_path.parent, json.loads(v)) + for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True)) + ] + result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True) + return result + + +@app.command() +def dump_processed_data( + data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"), + dump_path: Path = Path("./data/processed_data.json"), +): + typer.echo(f"Using data manifest:{data_manifest_path}") + with data_manifest_path.open("r") as pf: + pnr_jsonl = pf.readlines() + pnr_data = [ + preprocess_datapoint(i, data_manifest_path.parent, json.loads(v)) + for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True)) + ] + result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True) + dump_path = Path("./data/processed_data.json") + dump_f = dump_path.open("w") + json.dump(result, dump_f, indent=2) + dump_f.close() + + +@app.command() +def fill_unannotated( + processed_data_path: Path = Path("./data/processed_data.json"), + corrections_path: Path = Path("./data/corrections.json"), +): + processed_data = json.load(processed_data_path.open()) + corrections = json.load(corrections_path.open()) + annotated_codes = {c["code"] for c in corrections} + all_codes = {c["gold_chars"] for c in processed_data} + unann_codes = all_codes - annotated_codes + mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation + for c in unann_codes: + mongo_conn.find_one_and_update( + {"type": "correction", "code": c}, + {"$set": {"value": {"status": "Inaudible", "correction": ""}}}, + upsert=True, + ) + + +@app.command() +def update_corrections( + data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"), + processed_data_path: Path = Path("./data/processed_data.json"), + corrections_path: Path = Path("./data/corrections.json"), +): + def correct_manifest(manifest_data_gen, corrections_path): + corrections = json.load(corrections_path.open()) + correct_set = { + c["code"] for c in corrections if c["value"]["status"] == "Correct" + } + # incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"} + correction_map = { + c["code"]: c["value"]["correction"] + for c in corrections + if c["value"]["status"] == "Incorrect" + } + # for d in manifest_data_gen: + # if d["chars"] in incorrect_set: + # d["audio_path"].unlink() + renamed_set = set() + for d in manifest_data_gen: + if d["chars"] in correct_set: + yield { + "audio_filepath": d["audio_filepath"], + "duration": d["duration"], + "text": d["text"], + } + elif d["chars"] in correction_map: + correct_text = correction_map[d["chars"]] + renamed_set.add(correct_text) + new_name = str(Path(correct_text).with_suffix(".wav")) + d["audio_path"].replace(d["audio_path"].with_name(new_name)) + new_filepath = str(Path(d["audio_filepath"]).with_name(new_name)) + yield { + "audio_filepath": new_filepath, + "duration": d["duration"], + "text": alnum_to_asr_tokens(correct_text), + } + else: + # don't delete if another correction points to an old file + if d["chars"] not in renamed_set: + d["audio_path"].unlink() + else: + print(f'skipping deletion of correction:{d["chars"]}') + + typer.echo(f"Using data manifest:{data_manifest_path}") + dataset_dir = data_manifest_path.parent + dataset_name = dataset_dir.name + backup_dir = dataset_dir.with_name(dataset_name + ".bkp") + if not backup_dir.exists(): + typer.echo(f"backing up to :{backup_dir}") + shutil.copytree(str(dataset_dir), str(backup_dir)) + manifest_gen = asr_manifest_reader(data_manifest_path) + corrected_manifest = correct_manifest(manifest_gen, corrections_path) + new_data_manifest_path = data_manifest_path.with_name("manifest.new") + asr_manifest_writer(new_data_manifest_path, corrected_manifest) + new_data_manifest_path.replace(data_manifest_path) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data_utils/validation/ui.py b/jasper/data_utils/validation/ui.py index 8e61eaf..44cf319 100644 --- a/jasper/data_utils/validation/ui.py +++ b/jasper/data_utils/validation/ui.py @@ -14,7 +14,6 @@ import typer from .jasper_client import transcriber_pretrained, transcriber_speller from .st_rerun import rerun - app = typer.Typer() st.title("ASR Speller Validation") @@ -53,7 +52,6 @@ if not hasattr(st, "mongo_connected"): {"$set": {"value": value}}, upsert=True, ) - rerun() cursor_obj = mongo_conn.find_one({"type": "current_cursor"}) if not cursor_obj: @@ -76,7 +74,6 @@ def preprocess_datapoint(idx, rel, sample): audio_path = rel / Path(sample["audio_filepath"]) res["audio_path"] = audio_path res["gold_chars"] = audio_path.stem - res["gold_phone"] = sample["text"] aud_seg = ( AudioSegment.from_wav(audio_path) .set_channels(1) @@ -85,7 +82,7 @@ def preprocess_datapoint(idx, rel, sample): ) res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) res["speller_asr"] = transcriber_speller(aud_seg.raw_data) - res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]]) + res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]]) (y, sr) = librosa.load(audio_path) plt.tight_layout() librosa.display.waveplot(y=y, sr=sr) @@ -116,7 +113,7 @@ def main(manifest: Path): sample_no = st.get_current_cursor() sample = pnr_data[sample_no] st.markdown( - f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['gold_phone']}*" + f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*" ) new_sample = st.number_input( "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data) @@ -125,7 +122,7 @@ def main(manifest: Path): st.update_cursor(new_sample - 1) st.sidebar.title(f"Details: [{sample['real_idx']}]") st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**") - st.sidebar.markdown(f"Expected Speech: *{sample['gold_phone']}*") + st.sidebar.markdown(f"Expected Speech: *{sample['text']}*") st.sidebar.title("Results:") st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}") st.sidebar.text(f"Speller:{sample['speller_asr']}") @@ -158,6 +155,7 @@ def main(manifest: Path): st.markdown( f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' ) + # real_idx = st.text_input("Go to real-index:", value=sample['real_idx']) # st.markdown( # ",".join( # [