from pathlib import Path import streamlit as st import typer from ..utils import ExtendedPath, get_mongo_conn from .st_rerun import rerun app = typer.Typer() if not hasattr(st, "mongo_connected"): st.mongoclient = get_mongo_conn().test.asr_validation mongo_conn = st.mongoclient def current_cursor_fn(): # mongo_conn = st.mongoclient cursor_obj = mongo_conn.find_one({"type": "current_cursor"}) cursor_val = cursor_obj["cursor"] return cursor_val def update_cursor_fn(val=0): mongo_conn.find_one_and_update( {"type": "current_cursor"}, {"$set": {"type": "current_cursor", "cursor": val}}, upsert=True, ) rerun() def get_correction_entry_fn(code): # mongo_conn = st.mongoclient # cursor_obj = mongo_conn.find_one({"type": "correction", "code": code}) # cursor_val = cursor_obj["cursor"] return mongo_conn.find_one( {"type": "correction", "code": code}, projection={"_id": False} ) def update_entry_fn(code, value): mongo_conn.find_one_and_update( {"type": "correction", "code": code}, {"$set": {"value": value}}, upsert=True, ) cursor_obj = mongo_conn.find_one({"type": "current_cursor"}) if not cursor_obj: update_cursor_fn(0) st.get_current_cursor = current_cursor_fn st.update_cursor = update_cursor_fn st.get_correction_entry = get_correction_entry_fn st.update_entry = update_entry_fn st.mongo_connected = True @st.cache() def load_ui_data(validation_ui_data_path: Path): typer.echo(f"Using validation ui data from {validation_ui_data_path}") return ExtendedPath(validation_ui_data_path).read_json() @app.command() def main(manifest: Path): ui_config = load_ui_data(manifest) asr_data = ui_config["data"] use_domain_asr = ui_config.get("use_domain_asr", True) annotation_only = ui_config.get("annotation_only", False) sample_no = st.get_current_cursor() if len(asr_data) - 1 < sample_no or sample_no < 0: print("Invalid samplno resetting to 0") st.update_cursor(0) sample = asr_data[sample_no] title_type = "Speller " if use_domain_asr else "" if annotation_only: st.title(f"ASR Annotation") else: st.title(f"ASR {title_type}Validation") addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else "" st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text) new_sample = st.number_input( "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) ) if new_sample != sample_no + 1: st.update_cursor(new_sample - 1) st.sidebar.title(f"Details: [{sample['real_idx']}]") st.sidebar.markdown(f"Gold Text: **{sample['text']}**") if not annotation_only: if use_domain_asr: st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*") st.sidebar.title("Results:") st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**") if use_domain_asr: st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**") st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%") else: st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%") st.sidebar.image(Path(sample["plot_path"]).read_bytes()) st.audio(Path(sample["audio_path"]).open("rb")) # set default to text corrected = sample["text"] correction_entry = st.get_correction_entry(sample["utterance_id"]) selected_idx = 0 options = ("Correct", "Incorrect", "Inaudible") # if correction entry is present set the corresponding ui defaults if correction_entry: selected_idx = options.index(correction_entry["value"]["status"]) corrected = correction_entry["value"]["correction"] selected = st.radio("The Audio is", options, index=selected_idx) if selected == "Incorrect": corrected = st.text_input("Actual:", value=corrected) if selected == "Inaudible": corrected = "" if st.button("Submit"): correct_code = corrected.replace(" ", "").upper() st.update_entry( sample["utterance_id"], {"status": selected, "correction": correct_code} ) st.update_cursor(sample_no + 1) if correction_entry: st.markdown( f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' ) text_sample = st.text_input("Go to Text:", value='') if text_sample != '': candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample] if len(candidates) > 0: st.update_cursor(candidates[0]) real_idx = st.number_input( "Go to real-index", value=sample["real_idx"], min_value=0, max_value=len(asr_data) - 1, ) if real_idx != int(sample["real_idx"]): idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0] st.update_cursor(idx) if __name__ == "__main__": try: app() except SystemExit: pass