parent 42647196fe
commit e8f58a5043

@@ -0,0 +1,4 @@
+[flake8]
+exclude = docs
+ignore = E203, W503
+max-line-length = 119
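Note: this flake8 config is the standard black-compatible one. E203 (whitespace before ':') and W503 (line break before a binary operator) are the two checks that black's output routinely trips, and the 119-character limit matches the rewrapped call sites later in this commit.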
@@ -58,23 +58,10 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     return num_datapoints


-def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+def ui_data_generator(output_dir, dataset_name, asr_data_source, verbose=False):
     dataset_dir = output_dir / Path(dataset_name)
     (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
-    ui_dump_file = dataset_dir / Path("ui_dump.json")
     (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
-    asr_manifest = dataset_dir / Path("manifest.json")
-    num_datapoints = 0
-    ui_dump = {
-        "use_domain_asr": False,
-        "annotation_only": False,
-        "enable_plots": True,
-        "data": [],
-    }
-    data_funcs = []
-    transcriber_pretrained = transcribe_gen(asr_port=8044)
-    with asr_manifest.open("w") as mf:
-        print(f"writing manifest to {asr_manifest}")

     def data_fn(
         transcript,
@@ -89,9 +76,8 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     ):
         pretrained_result = transcriber_pretrained(aud_seg.raw_data)
         pretrained_wer = word_error_rate([transcript], [pretrained_result])
-        wav_plot_path = (
-            dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
-        )
+        png_path = Path(fname).with_suffix(".png")
+        wav_plot_path = dataset_dir / Path("wav_plots") / png_path
         if not wav_plot_path.exists():
             plot_seg(wav_plot_path, audio_path)
         return {
@@ -108,14 +94,15 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
             "plot_path": str(wav_plot_path),
         }

+    num_datapoints = 0
+    data_funcs = []
+    transcriber_pretrained = transcribe_gen(asr_port=8044)
     for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source:
         fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
         audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
         audio_file.write_bytes(wav_data)
         audio_path = str(audio_file)
         rel_data_path = audio_file.relative_to(dataset_dir)
-        manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
-        mf.write(manifest)
         data_funcs.append(
             partial(
                 data_fn,
@@ -131,10 +118,28 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
             )
         )
         num_datapoints += 1
-    dump_data = parallel_apply(lambda x: x(), data_funcs)
-    # dump_data = [x() for x in tqdm(data_funcs)]
-    ui_dump["data"] = dump_data
-    ExtendedPath(ui_dump_file).write_json(ui_dump)
+    ui_data = parallel_apply(lambda x: x(), data_funcs)
+    return ui_data, num_datapoints
+
+
+def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False):
+    dataset_dir = output_dir / Path(dataset_name)
+    dump_data, num_datapoints = ui_data_generator(
+        output_dir, dataset_name, asr_data_source, verbose=verbose
+    )
+
+    asr_manifest = dataset_dir / Path("manifest.json")
+    with asr_manifest.open("w") as mf:
+        print(f"writing manifest to {asr_manifest}")
+        for d in dump_data:
+            rel_data_path = d["audio_filepath"]
+            audio_dur = d["duration"]
+            transcript = d["text"]
+            manifest = manifest_str(str(rel_data_path), audio_dur, transcript)
+            mf.write(manifest)
+
+    ui_dump_file = dataset_dir / Path("ui_dump.json")
+    ExtendedPath(ui_dump_file).write_json({"data": dump_data})
     return num_datapoints
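The refactor splits the old monolithic writer in two: ui_data_generator builds the per-utterance records (fanned out through parallel_apply) and returns (ui_data, num_datapoints), while ui_dump_manifest_writer consumes that list to write both manifest.json and ui_dump.json. A minimal usage sketch, assuming the repo's utils are importable and the pretrained ASR service is listening on port 8044 (data_fn transcribes every clip); toy_source and the paths are hypothetical, but the 5-tuple it yields is exactly what the for-loop above unpacks:

    from pathlib import Path
    from pydub import AudioSegment

    def toy_source():
        # yields (transcript, audio_dur, wav_data, caller_name, aud_seg)
        seg = AudioSegment.silent(duration=1000)  # 1 s of silence
        wav_data = seg.export(format="wav").read()
        yield "hello world", seg.duration_seconds, wav_data, "caller", seg

    n = ui_dump_manifest_writer(Path("./data/asr_data"), "toy", toy_source())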
@@ -1,13 +1,10 @@
 import json
 import shutil
 from pathlib import Path
-from enum import Enum

 import typer
-from tqdm import tqdm

 from ..utils import (
-    alnum_to_asr_tokens,
     ExtendedPath,
     asr_manifest_reader,
     asr_manifest_writer,
@@ -19,9 +16,7 @@ from ..utils import (
 app = typer.Typer()


-def preprocess_datapoint(
-    idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
-):
+def preprocess_datapoint(idx, rel_root, sample):
     from pydub import AudioSegment
     from nemo.collections.asr.metrics import word_error_rate
     from jasper.client import transcribe_gen
@@ -31,12 +26,7 @@ def preprocess_datapoint(
     res["real_idx"] = idx
     audio_path = rel_root / Path(sample["audio_filepath"])
     res["audio_path"] = str(audio_path)
-    if use_domain_asr:
-        res["spoken"] = alnum_to_asr_tokens(res["text"])
-    else:
-        res["spoken"] = res["text"]
     res["utterance_id"] = audio_path.stem
-    if not annotation_only:
     transcriber_pretrained = transcribe_gen(asr_port=8044)

     aud_seg = (
@@ -46,16 +36,7 @@ def preprocess_datapoint(
         .set_frame_rate(24000)
     )
     res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
-    res["pretrained_wer"] = word_error_rate(
-        [res["text"]], [res["pretrained_asr"]]
-    )
-    if use_domain_asr:
-        transcriber_speller = transcribe_gen(asr_port=8045)
-        res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
-        res["domain_wer"] = word_error_rate(
-            [res["spoken"]], [res["pretrained_asr"]]
-        )
-    if enable_plots:
+    res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
     wav_plot_path = (
         rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
     )
@@ -73,61 +54,50 @@ def dump_ui(
     dataset_dir: Path = Path("./data/asr_data"),
     dump_dir: Path = Path("./data/valiation_data"),
     dump_fname: Path = typer.Option(Path("ui_dump.json"), show_default=True),
-    use_domain_asr: bool = False,
-    annotation_only: bool = False,
-    enable_plots: bool = True,
 ):
-    from concurrent.futures import ThreadPoolExecutor
-    from functools import partial
+    from io import BytesIO
+    from pydub import AudioSegment
+
+    from ..utils import ui_data_generator

     data_manifest_path = dataset_dir / Path(data_name) / Path("manifest.json")
-    dump_path: Path = dump_dir / Path(data_name) / dump_fname
     plot_dir = data_manifest_path.parent / Path("wav_plots")
     plot_dir.mkdir(parents=True, exist_ok=True)
     typer.echo(f"Using data manifest:{data_manifest_path}")

-    with data_manifest_path.open("r") as pf:
-        data_jsonl = pf.readlines()
-    data_funcs = [
-        partial(
-            preprocess_datapoint,
-            i,
-            data_manifest_path.parent,
-            json.loads(v),
-            use_domain_asr,
-            annotation_only,
-            enable_plots,
-        )
-        for i, v in enumerate(data_jsonl)
-    ]
-
-    def exec_func(f):
-        return f()
-
-    with ThreadPoolExecutor() as exe:
-        print("starting all preprocess tasks")
-        data_final = filter(
-            None,
-            list(
-                tqdm(
-                    exe.map(exec_func, data_funcs),
-                    position=0,
-                    leave=True,
-                    total=len(data_funcs),
-                )
-            ),
-        )
-    if annotation_only:
-        result = list(data_final)
-    else:
-        wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
-        result = sorted(data_final, key=lambda x: x[wer_key], reverse=True)
-    ui_config = {
-        "use_domain_asr": use_domain_asr,
-        "annotation_only": annotation_only,
-        "enable_plots": enable_plots,
-        "data": result,
-    }
-    ExtendedPath(dump_path).write_json(ui_config)
+    def asr_data_source_gen():
+        with data_manifest_path.open("r") as pf:
+            data_jsonl = pf.readlines()
+        for v in data_jsonl:
+            sample = json.loads(v)
+            rel_root = data_manifest_path.parent
+            res = dict(sample)
+            audio_path = rel_root / Path(sample["audio_filepath"])
+            audio_segment = (
+                AudioSegment.from_file_using_temporary_files(audio_path)
+                .set_channels(1)
+                .set_sample_width(2)
+                .set_frame_rate(24000)
+            )
+            wav_plot_path = (
+                rel_root
+                / Path("wav_plots")
+                / Path(audio_path.name).with_suffix(".png")
+            )
+            if not wav_plot_path.exists():
+                plot_seg(wav_plot_path, audio_path)
+            res["plot_path"] = str(wav_plot_path)
+            code_fb = BytesIO()
+            audio_segment.export(code_fb, format="wav")
+            wav_data = code_fb.getvalue()
+            duration = audio_segment.duration_seconds
+            asr_final = res["text"]
+            yield asr_final, duration, wav_data, "caller", audio_segment
+
+    dump_data, num_datapoints = ui_data_generator(
+        dataset_dir, data_name, asr_data_source_gen()
+    )
+
+    ui_dump_file = dataset_dir / Path("ui_dump.json")
+    ExtendedPath(ui_dump_file).write_json({"data": dump_data})


 @app.command()
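dump_ui now feeds the shared ui_data_generator through a local generator instead of mapping preprocess_datapoint over a thread pool, and the domain-ASR, annotation-only, and plot toggles move out of this command. The manifest it reads is JSON-lines, one record per utterance. A toy reader under that assumption (the "demo" path is hypothetical; the field names are the ones this diff reads and writes):

    import json
    from pathlib import Path

    def read_manifest(path: Path):
        # manifest.json is JSONL: one JSON object per line
        with path.open() as pf:
            for line in pf:
                yield json.loads(line)

    for sample in read_manifest(Path("./data/asr_data/demo/manifest.json")):
        print(sample["audio_filepath"], sample["duration"], sample["text"])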
@@ -190,7 +160,9 @@ def dump_corrections(
     col = get_mongo_conn(col="asr_validation")
     task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
     corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
-    cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
+    cursor_obj = col.find(
+        {"type": "correction", "task_id": task_id}, projection={"_id": False}
+    )
     corrections = [c for c in cursor_obj]
     ExtendedPath(dump_path).write_json(corrections)
@@ -264,7 +236,9 @@ def split_extract(
     dump_file: Path = Path("ui_dump.json"),
     manifest_file: Path = Path("manifest.json"),
     corrections_file: str = typer.Option("corrections.json", show_default=True),
-    conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
+    conv_data_path: Path = typer.Option(
+        Path("./data/conv_data.json"), show_default=True
+    ),
     extraction_type: str = "all",
 ):
     import shutil
@@ -286,7 +260,9 @@ def split_extract(
     def extract_manifest(mg):
         for m in mg:
             if m["text"] in extraction_vals:
-                shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
+                shutil.copy(
+                    m["audio_path"], dest_data_dir / Path(m["audio_filepath"])
+                )
                 yield m

     asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
@@ -295,12 +271,14 @@ def split_extract(
     orig_ui_data = ExtendedPath(ui_data_path).read_json()
     ui_data = orig_ui_data["data"]
     file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
-    extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
+    extracted_ui_data = list(
+        filter(lambda u: u["text"] in extraction_vals, ui_data)
+    )
     final_data = []
     for i, d in enumerate(extracted_ui_data):
-        d['real_idx'] = i
+        d["real_idx"] = i
         final_data.append(d)
-    orig_ui_data['data'] = final_data
+    orig_ui_data["data"] = final_data
     ExtendedPath(dest_ui_path).write_json(orig_ui_data)

     if corrections_file:
@@ -316,7 +294,7 @@ def split_extract(
     )
     ExtendedPath(dest_correction_path).write_json(extracted_corrections)

-    if extraction_type.value == 'all':
+    if extraction_type.value == "all":
         for ext_key in conv_data.keys():
             extract_data_of_type(ext_key)
     else:
@@ -338,7 +316,7 @@ def update_corrections(

 def correct_manifest(ui_dump_path, corrections_path):
     corrections = ExtendedPath(corrections_path).read_json()
-    ui_data = ExtendedPath(ui_dump_path).read_json()['data']
+    ui_data = ExtendedPath(ui_dump_path).read_json()["data"]
     correct_set = {
         c["code"] for c in corrections if c["value"]["status"] == "Correct"
     }
@@ -367,7 +345,9 @@ def update_corrections(
             )
         else:
             orig_audio_path = Path(d["audio_path"])
-            new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
+            new_name = str(
+                Path(tscript_uuid_fname(correct_text)).with_suffix(".wav")
+            )
             new_audio_path = orig_audio_path.with_name(new_name)
             orig_audio_path.replace(new_audio_path)
             new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
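These hunks are mostly mechanical formatting: long calls are wrapped to fit the new 119-character limit and single quotes become double quotes, i.e. black's style under the flake8 settings added above. The one behavioral detail worth a second look is split_extract's re-indexing, which renumbers real_idx so the filtered subset stays densely indexed for the UI cursor. The same idea in isolation (toy data, not the repo's code):

    ui_data = [{"text": "yes", "real_idx": 7}, {"text": "no", "real_idx": 9}]
    extraction_vals = {"yes"}
    extracted = [u for u in ui_data if u["text"] in extraction_vals]
    final_data = []
    for i, d in enumerate(extracted):
        d["real_idx"] = i  # renumber so indices stay dense after filtering
        final_data.append(d)
    assert final_data == [{"text": "yes", "real_idx": 0}]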
@@ -72,22 +72,18 @@ def main(manifest: Path, task_id: str = ""):
     st.set_task(manifest, task_id)
     ui_config = load_ui_data(manifest)
     asr_data = ui_config["data"]
-    use_domain_asr = ui_config.get("use_domain_asr", True)
     annotation_only = ui_config.get("annotation_only", False)
-    enable_plots = ui_config.get("enable_plots", True)
     sample_no = st.get_current_cursor()
     if len(asr_data) - 1 < sample_no or sample_no < 0:
         print("Invalid samplno resetting to 0")
         st.update_cursor(0)
     sample = asr_data[sample_no]
-    title_type = "Speller " if use_domain_asr else ""
     task_uid = st.task_id.rsplit("-", 1)[1]
     if annotation_only:
         st.title(f"ASR Annotation - # {task_uid}")
     else:
-        st.title(f"ASR {title_type}Validation - # {task_uid}")
-    addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
-    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
+        st.title(f"ASR Validation - # {task_uid}")
+    st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**")
     new_sample = st.number_input(
         "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
     )
@@ -96,18 +92,12 @@ def main(manifest: Path, task_id: str = ""):
     st.sidebar.title(f"Details: [{sample['real_idx']}]")
     st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
     if not annotation_only:
-        if use_domain_asr:
-            st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
         st.sidebar.title("Results:")
         st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
         if "caller" in sample:
             st.sidebar.markdown(f"Caller: **{sample['caller']}**")
-        if use_domain_asr:
-            st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
-            st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
-        else:
-            st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
-    if enable_plots:
+        st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
         st.sidebar.image(Path(sample["plot_path"]).read_bytes())
     st.audio(Path(sample["audio_path"]).open("rb"))
     # set default to text
@@ -130,16 +120,12 @@ def main(manifest: Path, task_id: str = ""):
     )
     st.update_cursor(sample_no + 1)
     if correction_entry:
-        st.markdown(
-            f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
-        )
+        status = correction_entry["value"]["status"]
+        correction = correction_entry["value"]["correction"]
+        st.markdown(f"Your Response: **{status}** Correction: **{correction}**")
     text_sample = st.text_input("Go to Text:", value="")
     if text_sample != "":
-        candidates = [
-            i
-            for (i, p) in enumerate(asr_data)
-            if p["text"] == text_sample or p["spoken"] == text_sample
-        ]
+        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample]
         if len(candidates) > 0:
             st.update_cursor(candidates[0])
     real_idx = st.number_input(
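In the validation UI, dropping use_domain_asr and enable_plots removes all the speller branches, and the go-to-text search loses its "spoken" fallback since that field is no longer produced. The navigation pattern itself is unchanged: a 1-based number_input drives a 0-based cursor. A bare-streamlit sketch of that pattern (the diff uses a project wrapper around st with get_current_cursor/update_cursor; session_state below is an assumed stand-in, and the records are toy data):

    import streamlit as st

    asr_data = [{"text": "yes"}, {"text": "no"}]  # toy records
    if "cursor" not in st.session_state:
        st.session_state["cursor"] = 0
    sample_no = st.session_state["cursor"]
    new_sample = st.number_input(
        "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
    )
    st.session_state["cursor"] = int(new_sample) - 1  # back to 0-based
    sample = asr_data[st.session_state["cursor"]]
    st.markdown(f"{int(new_sample)} of {len(asr_data)} : **{sample['text']}**")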