jasper-asr/jasper/data/validation/ui.py

159 lines
5.7 KiB
Python

from pathlib import Path
import streamlit as st
import typer
from uuid import uuid4
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
app = typer.Typer()

# Streamlit re-executes this script on every interaction. The `hasattr` guard
# makes the block below run only once per server process (presumably because
# the `st` module object survives reruns), and the connection plus helper
# callables are stashed as attributes on `st` for later runs to reuse.
if not hasattr(st, "mongo_connected"):
    st.mongoclient = get_mongo_conn(col="asr_validation")
    mongo_conn = st.mongoclient
    # Random id for this process; cursor documents are keyed by it and
    # correction documents record it as the last writer.
    st.task_id = str(uuid4())

    def _store_cursor(val):
        """Upsert this task's cursor document with ``val`` (no rerun)."""
        mongo_conn.find_one_and_update(
            {"type": "current_cursor", "task_id": st.task_id},
            {"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
            upsert=True,
        )

    def current_cursor_fn():
        """Return this task's stored cursor, or 0 when none has been stored."""
        cursor_obj = mongo_conn.find_one(
            {"type": "current_cursor", "task_id": st.task_id}
        )
        if cursor_obj is None:
            # Missing document used to crash with TypeError on None["cursor"].
            return 0
        return cursor_obj["cursor"]

    def update_cursor_fn(val=0):
        """Persist ``val`` as this task's cursor, then rerun the script."""
        _store_cursor(val)
        rerun()

    def get_correction_entry_fn(code):
        """Return the stored correction document for ``code`` (without ``_id``).

        Returns None when no correction has been saved for ``code``.
        """
        return mongo_conn.find_one(
            {"type": "correction", "code": code}, projection={"_id": False}
        )

    def update_entry_fn(code, value):
        """Upsert the correction ``value`` for ``code``, tagged with this task id."""
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": value, "task_id": st.task_id}},
            upsert=True,
        )

    def set_task_fn(mf_path):
        """Create an empty ``task-<task_id>.lck`` file next to ``mf_path``."""
        task_path = mf_path.parent / f"task-{st.task_id}.lck"
        if not task_path.exists():
            print(f"creating task lock at {task_path}")
            task_path.touch()

    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.get_correction_entry = get_correction_entry_fn
    st.update_entry = update_entry_fn
    st.set_task = set_task_fn
    st.mongo_connected = True
    cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
    if not cursor_obj:
        # Seed the cursor directly: nothing has been rendered yet, so going
        # through update_cursor_fn would only force a redundant rerun.
        _store_cursor(0)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Load the validation UI JSON document stored at *validation_ui_data_path*.

    Logs the path being used, then returns the parsed JSON. The function is
    wrapped in ``st.cache`` so reruns with the same path can reuse the result
    instead of re-reading the file.
    """
    message = f"Using validation ui data from {validation_ui_data_path}"
    typer.echo(message)
    ui_data = ExtendedPath(validation_ui_data_path).read_json()
    return ui_data
@app.command()
def main(manifest: Path):
    """Render the ASR validation / annotation page for the current sample.

    Args:
        manifest: Path to the validation UI JSON (read via ``load_ui_data``).
            Its ``"data"`` list holds the samples. The optional keys
            ``use_domain_asr`` (default True), ``annotation_only`` (default
            False) and ``enable_plots`` (default True) toggle page sections.

    The displayed sample is selected by the cursor from
    ``st.get_current_cursor``. Every navigation action stores a new cursor via
    ``st.update_cursor``, which calls ``rerun()``. Code after such a call is
    therefore normally not reached (assumes ``rerun()`` raises — TODO confirm
    in st_rerun).
    """
    st.set_task(manifest)
    ui_config = load_ui_data(manifest)
    asr_data = ui_config["data"]
    use_domain_asr = ui_config.get("use_domain_asr", True)
    annotation_only = ui_config.get("annotation_only", False)
    enable_plots = ui_config.get("enable_plots", True)

    if not asr_data:
        # No cursor value is valid for an empty list, so the "reset to 0"
        # recovery below could never succeed.
        st.warning(f"No samples found in {manifest}")
        return

    sample_no = st.get_current_cursor()
    if len(asr_data) - 1 < sample_no or sample_no < 0:
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
        sample_no = 0  # only reached if update_cursor returns without rerunning
    sample = asr_data[sample_no]

    title_type = "Speller " if use_domain_asr else ""
    task_uid = st.task_id.rsplit("-", 1)[1]
    if annotation_only:
        st.title(f"ASR Annotation - # {task_uid}")
    else:
        st.title(f"ASR {title_type}Validation - # {task_uid}")
    # Leading space keeps the spelled form from running into the bold text.
    addl_text = f" spelled *{sample['spoken']}*" if use_domain_asr else ""
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)

    # The widget is 1-based; the cursor is a 0-based index into asr_data.
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)

    st.sidebar.title(f"Details: [{sample['real_idx']}]")
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    if not annotation_only:
        if use_domain_asr:
            st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
        st.sidebar.title("Results:")
        st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
        if "caller" in sample:
            st.sidebar.markdown(f"Caller: **{sample['caller']}**")
        if use_domain_asr:
            st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
            st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
        else:
            st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
        if enable_plots:
            st.sidebar.image(Path(sample["plot_path"]).read_bytes())

    # Pass bytes rather than an open handle so no file descriptor is leaked.
    st.audio(Path(sample["audio_path"]).read_bytes())

    # Default the correction to the gold text.
    corrected = sample["text"]
    correction_entry = st.get_correction_entry(sample["utterance_id"])
    selected_idx = 0
    options = ("Correct", "Incorrect", "Inaudible")
    # Pre-fill the widgets from a previously saved answer, if any.
    if correction_entry:
        stored_status = correction_entry["value"]["status"]
        if stored_status in options:
            selected_idx = options.index(stored_status)
        corrected = correction_entry["value"]["correction"]
    selected = st.radio("The Audio is", options, index=selected_idx)
    if selected == "Incorrect":
        corrected = st.text_input("Actual:", value=corrected)
    if selected == "Inaudible":
        corrected = ""
    if st.button("Submit"):
        st.update_entry(
            sample["utterance_id"], {"status": selected, "correction": corrected}
        )
        st.update_cursor(sample_no + 1)
    if correction_entry:
        saved = correction_entry["value"]
        st.markdown(
            f"Your Response: **{saved['status']}** Correction: **{saved['correction']}**"
        )

    text_sample = st.text_input("Go to Text:", value="")
    if text_sample != "":
        # "spoken" may be absent when the data was built without domain ASR.
        candidates = [
            i
            for (i, p) in enumerate(asr_data)
            if p["text"] == text_sample or p.get("spoken") == text_sample
        ]
        # The text box keeps its value across reruns. Jumping unconditionally
        # would call update_cursor (and rerun) forever once the match is shown.
        # NOTE(review): while the box is non-empty it still pulls the cursor
        # back to the match after other navigation.
        if candidates and candidates[0] != sample_no:
            st.update_cursor(candidates[0])

    real_idx = st.number_input(
        "Go to real-index",
        value=sample["real_idx"],
        min_value=0,
        max_value=len(asr_data) - 1,
    )
    if real_idx != int(sample["real_idx"]):
        target = next(
            (i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx),
            None,
        )
        if target is None:
            # Previously an IndexError; real indices need not be contiguous.
            st.warning(f"No sample with real-index {real_idx}")
        else:
            st.update_cursor(target)
if __name__ == "__main__":
    # Typer (via Click's standalone mode) ends an invocation by raising
    # SystemExit. The exception is deliberately discarded so the script just
    # returns, presumably to keep `streamlit run` from reporting it.
    try:
        app()
    except SystemExit:
        ...