mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 19:02:35 +00:00
1. added support for mono/dual channel rev transcripts
2. handle errors when extracting datapoints from rev metadata 3. added support for annotation-only task when dumping ui data
This commit is contained in:
1
jasper/data/validation/__init__.py
Normal file
1
jasper/data/validation/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -16,16 +16,12 @@ from ..utils import (
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr, annotation_only):
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.client import (
|
||||
transcriber_pretrained,
|
||||
transcriber_speller,
|
||||
)
|
||||
|
||||
try:
|
||||
res = dict(sample)
|
||||
@@ -40,13 +36,18 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
||||
if use_domain_asr:
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
if not annotation_only:
|
||||
from jasper.client import transcriber_pretrained, transcriber_speller
|
||||
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate(
|
||||
[res["text"]], [res["pretrained_asr"]]
|
||||
)
|
||||
if use_domain_asr:
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
)
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
@@ -67,9 +68,14 @@ def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||
|
||||
@app.command()
|
||||
def dump_validation_ui_data(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
data_manifest_path: Path = typer.Option(
|
||||
Path("./data/asr_data/call_alphanum/manifest.json"), show_default=True
|
||||
),
|
||||
dump_path: Path = typer.Option(
|
||||
Path("./data/valiation_data/ui_dump.json"), show_default=True
|
||||
),
|
||||
use_domain_asr: bool = True,
|
||||
annotation_only: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
@@ -86,6 +92,7 @@ def dump_validation_ui_data(
|
||||
data_manifest_path.parent,
|
||||
json.loads(v),
|
||||
use_domain_asr,
|
||||
annotation_only,
|
||||
)
|
||||
for i, v in enumerate(pnr_jsonl)
|
||||
]
|
||||
@@ -94,7 +101,7 @@ def dump_validation_ui_data(
|
||||
return f()
|
||||
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all plot tasks")
|
||||
print("starting all preprocess tasks")
|
||||
pnr_data = filter(
|
||||
None,
|
||||
list(
|
||||
@@ -106,9 +113,16 @@ def dump_validation_ui_data(
|
||||
)
|
||||
),
|
||||
)
|
||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||
ui_config = {"use_domain_asr": use_domain_asr, "data": result}
|
||||
if annotation_only:
|
||||
result = pnr_data
|
||||
else:
|
||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||
ui_config = {
|
||||
"use_domain_asr": use_domain_asr,
|
||||
"data": result,
|
||||
"annotation_only": annotation_only,
|
||||
}
|
||||
ExtendedPath(dump_path).write_json(ui_config)
|
||||
|
||||
|
||||
@@ -171,7 +185,9 @@ def update_corrections(
|
||||
elif d["chars"] in correction_map:
|
||||
correct_text = correction_map[d["chars"]]
|
||||
if skip_incorrect:
|
||||
print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
|
||||
print(
|
||||
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
|
||||
)
|
||||
else:
|
||||
renamed_set.add(correct_text)
|
||||
new_name = str(Path(correct_text).with_suffix(".wav"))
|
||||
|
||||
@@ -61,14 +61,18 @@ def load_ui_data(validation_ui_data_path: Path):
|
||||
def main(manifest: Path):
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config["use_domain_asr"]
|
||||
use_domain_asr = ui_config.get("use_domain_asr", True)
|
||||
annotation_only = ui_config.get("annotation_only", False)
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
st.update_cursor(0)
|
||||
sample = asr_data[sample_no]
|
||||
title_type = "Speller " if use_domain_asr else ""
|
||||
st.title(f"ASR {title_type}Validation")
|
||||
if annotation_only:
|
||||
st.title(f"ASR Annotation")
|
||||
else:
|
||||
st.title(f"ASR {title_type}Validation")
|
||||
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||
new_sample = st.number_input(
|
||||
@@ -78,15 +82,16 @@ def main(manifest: Path):
|
||||
st.update_cursor(new_sample - 1)
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
if not annotation_only:
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
@@ -113,10 +118,6 @@ def main(manifest: Path):
|
||||
st.markdown(
|
||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||
)
|
||||
# if st.button("Previous Untagged"):
|
||||
# pass
|
||||
# if st.button("Next Untagged"):
|
||||
# pass
|
||||
text_sample = st.text_input("Go to Text:", value='')
|
||||
if text_sample != '':
|
||||
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
|
||||
|
||||
Reference in New Issue
Block a user