mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 19:02:35 +00:00
refactored module structure
This commit is contained in:
221
jasper/data/validation/process.py
Normal file
221
jasper/data/validation/process.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import (
|
||||
alnum_to_asr_tokens,
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
get_mongo_conn,
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.client import (
|
||||
transcriber_pretrained,
|
||||
transcriber_speller,
|
||||
)
|
||||
|
||||
try:
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
res["spoken"] = alnum_to_asr_tokens(res["text"])
|
||||
res["utterance_id"] = audio_path.stem
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
||||
if use_domain_asr:
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
)
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
fig = plt.Figure()
|
||||
ax = fig.add_subplot()
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
||||
with wav_plot_path.open("wb") as wav_plot_f:
|
||||
fig.set_tight_layout(True)
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
# fig.close()
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
return res
|
||||
except BaseException as e:
|
||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_validation_ui_data(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
use_domain_asr: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_funcs = [
|
||||
partial(
|
||||
preprocess_datapoint,
|
||||
i,
|
||||
data_manifest_path.parent,
|
||||
json.loads(v),
|
||||
use_domain_asr,
|
||||
)
|
||||
for i, v in enumerate(pnr_jsonl)
|
||||
]
|
||||
|
||||
def exec_func(f):
|
||||
return f()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=20) as exe:
|
||||
print("starting all plot tasks")
|
||||
pnr_data = filter(
|
||||
None,
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, pnr_funcs),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(pnr_funcs),
|
||||
)
|
||||
),
|
||||
)
|
||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||
ui_config = {"use_domain_asr": use_domain_asr, "data": result}
|
||||
ExtendedPath(dump_path).write_json(ui_config)
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
|
||||
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
|
||||
corrections = [c for c in cursor_obj]
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
|
||||
@app.command()
|
||||
def fill_unannotated(
|
||||
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
):
|
||||
processed_data = json.load(processed_data_path.open())
|
||||
corrections = json.load(corrections_path.open())
|
||||
annotated_codes = {c["code"] for c in corrections}
|
||||
all_codes = {c["gold_chars"] for c in processed_data}
|
||||
unann_codes = all_codes - annotated_codes
|
||||
mongo_conn = get_mongo_conn().test.asr_validation
|
||||
for c in unann_codes:
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": c},
|
||||
{"$set": {"value": {"status": "Inaudible", "correction": ""}}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def update_corrections(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
skip_incorrect: bool = True,
|
||||
):
|
||||
def correct_manifest(manifest_data_gen, corrections_path):
|
||||
corrections = json.load(corrections_path.open())
|
||||
correct_set = {
|
||||
c["code"] for c in corrections if c["value"]["status"] == "Correct"
|
||||
}
|
||||
# incorrect_set = {c["code"] for c in corrections if c["value"]["status"] == "Inaudible"}
|
||||
correction_map = {
|
||||
c["code"]: c["value"]["correction"]
|
||||
for c in corrections
|
||||
if c["value"]["status"] == "Incorrect"
|
||||
}
|
||||
# for d in manifest_data_gen:
|
||||
# if d["chars"] in incorrect_set:
|
||||
# d["audio_path"].unlink()
|
||||
renamed_set = set()
|
||||
for d in manifest_data_gen:
|
||||
if d["chars"] in correct_set:
|
||||
yield {
|
||||
"audio_filepath": d["audio_filepath"],
|
||||
"duration": d["duration"],
|
||||
"text": d["text"],
|
||||
}
|
||||
elif d["chars"] in correction_map:
|
||||
correct_text = correction_map[d["chars"]]
|
||||
if skip_incorrect:
|
||||
print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
|
||||
else:
|
||||
renamed_set.add(correct_text)
|
||||
new_name = str(Path(correct_text).with_suffix(".wav"))
|
||||
d["audio_path"].replace(d["audio_path"].with_name(new_name))
|
||||
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
|
||||
yield {
|
||||
"audio_filepath": new_filepath,
|
||||
"duration": d["duration"],
|
||||
"text": alnum_to_asr_tokens(correct_text),
|
||||
}
|
||||
else:
|
||||
# don't delete if another correction points to an old file
|
||||
if d["chars"] not in renamed_set:
|
||||
d["audio_path"].unlink()
|
||||
else:
|
||||
print(f'skipping deletion of correction:{d["chars"]}')
|
||||
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
dataset_dir = data_manifest_path.parent
|
||||
dataset_name = dataset_dir.name
|
||||
backup_dir = dataset_dir.with_name(dataset_name + ".bkp")
|
||||
if not backup_dir.exists():
|
||||
typer.echo(f"backing up to :{backup_dir}")
|
||||
shutil.copytree(str(dataset_dir), str(backup_dir))
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
|
||||
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
|
||||
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
|
||||
new_data_manifest_path.replace(data_manifest_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clear_mongo_corrections():
|
||||
delete = typer.confirm("are you sure you want to clear mongo collection it?")
|
||||
if delete:
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
col.delete_many({"type": "correction"})
|
||||
typer.echo("deleted mongo collection.")
|
||||
typer.echo("Aborted")
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
38
jasper/data/validation/st_rerun.py
Normal file
38
jasper/data/validation/st_rerun.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import streamlit.ReportThread as ReportThread
|
||||
from streamlit.ScriptRequestQueue import RerunData
|
||||
from streamlit.ScriptRunner import RerunException
|
||||
from streamlit.server.Server import Server
|
||||
|
||||
|
||||
def rerun():
|
||||
"""Rerun a Streamlit app from the top!"""
|
||||
widget_states = _get_widget_states()
|
||||
raise RerunException(RerunData(widget_states))
|
||||
|
||||
|
||||
def _get_widget_states():
|
||||
# Hack to get the session object from Streamlit.
|
||||
|
||||
ctx = ReportThread.get_report_ctx()
|
||||
|
||||
session = None
|
||||
|
||||
current_server = Server.get_current()
|
||||
if hasattr(current_server, '_session_infos'):
|
||||
# Streamlit < 0.56
|
||||
session_infos = Server.get_current()._session_infos.values()
|
||||
else:
|
||||
session_infos = Server.get_current()._session_info_by_id.values()
|
||||
|
||||
for session_info in session_infos:
|
||||
if session_info.session.enqueue == ctx.enqueue:
|
||||
session = session_info.session
|
||||
|
||||
if session is None:
|
||||
raise RuntimeError(
|
||||
"Oh noes. Couldn't get your Streamlit Session object"
|
||||
"Are you doing something fancy with threads?"
|
||||
)
|
||||
# Got the session object!
|
||||
|
||||
return session._widget_states
|
||||
140
jasper/data/validation/ui.py
Normal file
140
jasper/data/validation/ui.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
import typer
|
||||
from ..utils import ExtendedPath, get_mongo_conn
|
||||
from .st_rerun import rerun
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
if not hasattr(st, "mongo_connected"):
|
||||
st.mongoclient = get_mongo_conn().test.asr_validation
|
||||
mongo_conn = st.mongoclient
|
||||
|
||||
def current_cursor_fn():
|
||||
# mongo_conn = st.mongoclient
|
||||
cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
|
||||
cursor_val = cursor_obj["cursor"]
|
||||
return cursor_val
|
||||
|
||||
def update_cursor_fn(val=0):
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "current_cursor"},
|
||||
{"$set": {"type": "current_cursor", "cursor": val}},
|
||||
upsert=True,
|
||||
)
|
||||
rerun()
|
||||
|
||||
def get_correction_entry_fn(code):
|
||||
# mongo_conn = st.mongoclient
|
||||
# cursor_obj = mongo_conn.find_one({"type": "correction", "code": code})
|
||||
# cursor_val = cursor_obj["cursor"]
|
||||
return mongo_conn.find_one(
|
||||
{"type": "correction", "code": code}, projection={"_id": False}
|
||||
)
|
||||
|
||||
def update_entry_fn(code, value):
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": code},
|
||||
{"$set": {"value": value}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
|
||||
if not cursor_obj:
|
||||
update_cursor_fn(0)
|
||||
st.get_current_cursor = current_cursor_fn
|
||||
st.update_cursor = update_cursor_fn
|
||||
st.get_correction_entry = get_correction_entry_fn
|
||||
st.update_entry = update_entry_fn
|
||||
st.mongo_connected = True
|
||||
|
||||
|
||||
@st.cache()
|
||||
def load_ui_data(validation_ui_data_path: Path):
|
||||
typer.echo(f"Using validation ui data from {validation_ui_data_path}")
|
||||
return ExtendedPath(validation_ui_data_path).read_json()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(manifest: Path):
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config["use_domain_asr"]
|
||||
sample_no = st.get_current_cursor()
|
||||
if len(asr_data) - 1 < sample_no or sample_no < 0:
|
||||
print("Invalid samplno resetting to 0")
|
||||
st.update_cursor(0)
|
||||
sample = asr_data[sample_no]
|
||||
title_type = "Speller " if use_domain_asr else ""
|
||||
st.title(f"ASR {title_type}Validation")
|
||||
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
if new_sample != sample_no + 1:
|
||||
st.update_cursor(new_sample - 1)
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
correction_entry = st.get_correction_entry(sample["utterance_id"])
|
||||
selected_idx = 0
|
||||
options = ("Correct", "Incorrect", "Inaudible")
|
||||
# if correction entry is present set the corresponding ui defaults
|
||||
if correction_entry:
|
||||
selected_idx = options.index(correction_entry["value"]["status"])
|
||||
corrected = correction_entry["value"]["correction"]
|
||||
selected = st.radio("The Audio is", options, index=selected_idx)
|
||||
if selected == "Incorrect":
|
||||
corrected = st.text_input("Actual:", value=corrected)
|
||||
if selected == "Inaudible":
|
||||
corrected = ""
|
||||
if st.button("Submit"):
|
||||
correct_code = corrected.replace(" ", "").upper()
|
||||
st.update_entry(
|
||||
sample["utterance_id"], {"status": selected, "correction": correct_code}
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
st.markdown(
|
||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||
)
|
||||
# if st.button("Previous Untagged"):
|
||||
# pass
|
||||
# if st.button("Next Untagged"):
|
||||
# pass
|
||||
text_sample = st.text_input("Go to Text:", value='')
|
||||
if text_sample != '':
|
||||
candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
|
||||
if len(candidates) > 0:
|
||||
st.update_cursor(candidates[0])
|
||||
real_idx = st.number_input(
|
||||
"Go to real-index",
|
||||
value=sample["real_idx"],
|
||||
min_value=0,
|
||||
max_value=len(asr_data) - 1,
|
||||
)
|
||||
if real_idx != int(sample["real_idx"]):
|
||||
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
|
||||
st.update_cursor(idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
app()
|
||||
except SystemExit:
|
||||
pass
|
||||
Reference in New Issue
Block a user