2020-04-29 08:56:11 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import streamlit as st
|
2020-04-29 11:52:45 +00:00
|
|
|
import typer
|
2020-05-12 18:08:06 +00:00
|
|
|
from ..utils import ExtendedPath, get_mongo_conn
|
2020-04-29 08:56:11 +00:00
|
|
|
from .st_rerun import rerun
|
|
|
|
|
|
2020-04-29 11:52:45 +00:00
|
|
|
app = typer.Typer()
|
2020-04-29 08:56:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(st, "mongo_connected"):
    # One-time setup: the accessors below are attached to the ``st`` module
    # and ``st.mongo_connected`` is set last, so this block is skipped once
    # the marker attribute exists.
    st.mongoclient = get_mongo_conn().test.asr_validation
    mongo_conn = st.mongoclient

    def _store_cursor(val):
        """Upsert the single ``current_cursor`` document with ``val``."""
        mongo_conn.find_one_and_update(
            {"type": "current_cursor"},
            {"$set": {"type": "current_cursor", "cursor": val}},
            upsert=True,
        )

    def current_cursor_fn():
        """Return the stored cursor value.

        Falls back to 0 when the cursor document (or its ``cursor`` field)
        is missing instead of failing on a ``None`` lookup result.
        """
        cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
        if cursor_obj is None:
            return 0
        return cursor_obj.get("cursor", 0)

    def update_cursor_fn(val=0):
        """Persist ``val`` as the current cursor, then call ``rerun()``.

        NOTE(review): ``rerun`` comes from ``.st_rerun``; it is presumably
        what restarts the script so the UI reflects the new cursor -- confirm.
        """
        _store_cursor(val)
        rerun()

    def get_correction_entry_fn(code):
        """Return the correction document stored for ``code``, without ``_id``.

        Returns ``None`` when no correction has been stored for ``code``.
        """
        return mongo_conn.find_one(
            {"type": "correction", "code": code}, projection={"_id": False}
        )

    def update_entry_fn(code, value):
        """Upsert the correction document for ``code``, setting its ``value``."""
        mongo_conn.find_one_and_update(
            {"type": "correction", "code": code},
            {"$set": {"value": value}},
            upsert=True,
        )

    cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
    if cursor_obj is None:
        # Seed the cursor directly rather than through update_cursor_fn():
        # that would also call rerun() before the accessors and the
        # ``mongo_connected`` marker below have been installed.
        _store_cursor(0)
    st.get_current_cursor = current_cursor_fn
    st.update_cursor = update_cursor_fn
    st.get_correction_entry = get_correction_entry_fn
    st.update_entry = update_entry_fn
    st.mongo_connected = True
|
|
|
|
|
|
|
|
|
|
|
2020-05-12 18:08:06 +00:00
|
|
|
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the validation UI data JSON found at ``validation_ui_data_path``.

    Echoes the path being used via ``typer.echo`` and returns whatever
    ``ExtendedPath.read_json`` yields for it. The function is wrapped in
    ``st.cache()``.
    """
    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
    ui_data_file = ExtendedPath(validation_ui_data_path)
    return ui_data_file.read_json()
|
2020-04-29 08:56:11 +00:00
|
|
|
|
|
|
|
|
|
2020-04-29 11:52:45 +00:00
|
|
|
@app.command()
def main(manifest: Path):
    """Render the ASR validation page for the sample at the stored cursor.

    ``manifest`` is the path to the validation UI data JSON. Its ``data``
    list holds the samples and ``use_domain_asr`` toggles the speller/domain
    specific parts of the page. Navigation and submitted corrections are
    persisted through the accessors attached to ``st`` (``get_current_cursor``,
    ``update_cursor``, ``get_correction_entry``, ``update_entry``).
    """
    ui_config = load_ui_data(manifest)
    asr_data = ui_config["data"]
    use_domain_asr = ui_config["use_domain_asr"]

    if not asr_data:
        # With no samples every cursor value is invalid; resetting it would
        # only request another rerun on each execution.
        st.warning("No samples found in the validation data.")
        return

    sample_no = st.get_current_cursor()
    if not 0 <= sample_no < len(asr_data):
        print("Invalid samplno resetting to 0")
        st.update_cursor(0)
        # NOTE(review): update_cursor is expected to restart the script run;
        # returning here guarantees the invalid index is never used below
        # even if it does not.
        return

    sample = asr_data[sample_no]
    title_type = "Speller " if use_domain_asr else ""
    st.title(f"ASR {title_type}Validation")
    addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)

    # The widget is 1-based for display; the stored cursor is 0-based.
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    if new_sample != sample_no + 1:
        st.update_cursor(new_sample - 1)

    st.sidebar.title(f"Details: [{sample['real_idx']}]")
    st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
    if use_domain_asr:
        st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
    st.sidebar.title("Results:")
    st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
    if use_domain_asr:
        st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
        st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
    else:
        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
    st.sidebar.image(Path(sample["plot_path"]).read_bytes())
    # Pass the audio as bytes (as done for the plot above) instead of an open
    # file object, so no file handle is left unclosed on each run.
    st.audio(Path(sample["audio_path"]).read_bytes())

    # set default to text
    corrected = sample["text"]
    correction_entry = st.get_correction_entry(sample["utterance_id"])
    selected_idx = 0
    options = ("Correct", "Incorrect", "Inaudible")
    # if correction entry is present set the corresponding ui defaults
    if correction_entry:
        stored = correction_entry["value"]
        # An unrecognised stored status keeps the first option selected
        # instead of raising ValueError from tuple.index().
        if stored["status"] in options:
            selected_idx = options.index(stored["status"])
        corrected = stored["correction"]

    selected = st.radio("The Audio is", options, index=selected_idx)
    if selected == "Incorrect":
        corrected = st.text_input("Actual:", value=corrected)
    if selected == "Inaudible":
        corrected = ""

    if st.button("Submit"):
        # Corrections are stored with spaces removed and upper-cased.
        correct_code = corrected.replace(" ", "").upper()
        st.update_entry(
            sample["utterance_id"], {"status": selected, "correction": correct_code}
        )
        # After the last sample this stores a past-the-end cursor, which the
        # range check at the top resets to 0 on the next run.
        st.update_cursor(sample_no + 1)

    if correction_entry:
        st.markdown(
            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
        )

    # TODO: add "Previous Untagged" / "Next Untagged" navigation buttons.

    real_idx = st.number_input(
        "Go to real-index",
        value=sample["real_idx"],
        min_value=0,
        max_value=len(asr_data) - 1,
    )
    if real_idx != int(sample["real_idx"]):
        # Position of the first sample carrying the requested real_idx, or
        # None when no sample has it (previously an IndexError).
        idx = next(
            (i for i, p in enumerate(asr_data) if p["real_idx"] == real_idx), None
        )
        if idx is None:
            st.warning(f"No sample with real-index {real_idx}")
        else:
            st.update_cursor(idx)
|
2020-04-29 08:56:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the typer application; a SystemExit raised while running it is
    # deliberately swallowed so it does not propagate out of this script.
    try:
        app()
    except SystemExit:
        pass
|