diff --git a/jasper/data_utils/validation/ui.py b/jasper/data_utils/validation/ui.py index 03fae29..23e1d0a 100644 --- a/jasper/data_utils/validation/ui.py +++ b/jasper/data_utils/validation/ui.py @@ -10,44 +10,15 @@ import matplotlib.pyplot as plt from tqdm import tqdm from pydub import AudioSegment import pymongo +import typer from .jasper_client import transcriber_pretrained, transcriber_speller from .st_rerun import rerun + +app = typer.Typer() st.title("ASR Speller Validation") -def clear_mongo_corrections(): - col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation - col.delete_many({"type": "correction"}) - - -def preprocess_datapoint(idx, sample): - res = dict(sample) - res["real_idx"] = idx - audio_path = Path(sample["audio_filepath"]) - res["audio_path"] = audio_path - res["gold_chars"] = audio_path.stem - res["gold_phone"] = sample["text"] - aud_seg = ( - AudioSegment.from_wav(audio_path) - .set_channels(1) - .set_sample_width(2) - .set_frame_rate(24000) - ) - res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) - res["speller_asr"] = transcriber_speller(aud_seg.raw_data) - res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]]) - (y, sr) = librosa.load(audio_path) - plt.tight_layout() - librosa.display.waveplot(y=y, sr=sr) - wav_plot_f = BytesIO() - plt.savefig(wav_plot_f, format="png", dpi=50) - plt.close() - wav_plot_f.seek(0) - res["plot_png"] = wav_plot_f - return res - - if not hasattr(st, "mongo_connected"): st.mongoclient = pymongo.MongoClient( "mongodb://localhost:27017/" @@ -94,23 +65,54 @@ if not hasattr(st, "mongo_connected"): st.mongo_connected = True +def clear_mongo_corrections(): + col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation + col.delete_many({"type": "correction"}) + + +def preprocess_datapoint(idx, rel, sample): + res = dict(sample) + res["real_idx"] = idx + audio_path = rel / Path(sample["audio_filepath"]) + res["audio_path"] = audio_path + res["gold_chars"] = audio_path.stem + res["gold_phone"] = sample["text"] + aud_seg = ( + AudioSegment.from_wav(audio_path) + .set_channels(1) + .set_sample_width(2) + .set_frame_rate(24000) + ) + res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) + res["speller_asr"] = transcriber_speller(aud_seg.raw_data) + res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]]) + (y, sr) = librosa.load(audio_path) + plt.tight_layout() + librosa.display.waveplot(y=y, sr=sr) + wav_plot_f = BytesIO() + plt.savefig(wav_plot_f, format="png", dpi=50) + plt.close() + wav_plot_f.seek(0) + res["plot_png"] = wav_plot_f + return res + + @st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None}) -def preprocess_dataset(dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3")): - print("misssed cache : preprocess_dataset") - dataset_path: Path = Path("/dataset/asr_data/call_alphanum_v3") - manifest_path = dataset_path / Path("test_manifest.json") - with manifest_path.open("r") as pf: +def preprocess_dataset(data_manifest_path: Path): + typer.echo(f"Using data manifest:{data_manifest_path}") + with data_manifest_path.open("r") as pf: pnr_jsonl = pf.readlines() pnr_data = [ - preprocess_datapoint(i, json.loads(v)) + preprocess_datapoint(i, data_manifest_path.parent, json.loads(v)) for i, v in enumerate(tqdm(pnr_jsonl)) ] result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True) return result -def main(): - pnr_data = preprocess_dataset() +@app.command() +def main(manifest: Path): + pnr_data = preprocess_dataset(manifest) sample_no = st.get_current_cursor() sample = pnr_data[sample_no] st.markdown( @@ -128,7 +130,7 @@ def main(): st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}") st.sidebar.text(f"Speller:{sample['speller_asr']}") - st.sidebar.title(f"WER: {sample['wer']:.2f}%") + st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%") # (y, sr) = librosa.load(sample["audio_path"]) # librosa.display.waveplot(y=y, sr=sr) # st.sidebar.pyplot(fig=sample["plot_fig"]) @@ -168,4 +170,7 @@ def main(): if __name__ == "__main__": - main() + try: + app() + except SystemExit: + pass