mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-06-13 12:32:08 +00:00
1. refactored ui_dump
2. added flake8
This commit is contained in:
@@ -72,22 +72,18 @@ def main(manifest: Path, task_id: str = ""):
|
||||
st.set_task(manifest, task_id)
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config.get("use_domain_asr", True)
|
||||
annotation_only = ui_config.get("annotation_only", False)
|
||||
enable_plots = ui_config.get("enable_plots", True)
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
st.update_cursor(0)
|
||||
sample = asr_data[sample_no]
|
||||
title_type = "Speller " if use_domain_asr else ""
|
||||
task_uid = st.task_id.rsplit("-", 1)[1]
|
||||
if annotation_only:
|
||||
st.title(f"ASR Annotation - # {task_uid}")
|
||||
else:
|
||||
st.title(f"ASR {title_type}Validation - # {task_uid}")
|
||||
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||
st.title(f"ASR Validation - # {task_uid}")
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
@@ -96,19 +92,13 @@ def main(manifest: Path, task_id: str = ""):
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if not annotation_only:
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if "caller" in sample:
|
||||
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
if enable_plots:
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
@@ -130,16 +120,12 @@ def main(manifest: Path, task_id: str = ""):
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
st.markdown(
|
||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||
)
|
||||
status = correction_entry["value"]["status"]
|
||||
correction = correction_entry["value"]["correction"]
|
||||
st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
|
||||
text_sample = st.text_input("Go to Text:", value="")
|
||||
if text_sample != "":
|
||||
candidates = [
|
||||
i
|
||||
for (i, p) in enumerate(asr_data)
|
||||
if p["text"] == text_sample or p["spoken"] == text_sample
|
||||
]
|
||||
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
|
||||
if len(candidates) > 0:
|
||||
st.update_cursor(candidates[0])
|
||||
real_idx = st.number_input(
|
||||
|
||||
Reference in New Issue
Block a user