mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 19:02:35 +00:00
1. added a tool to extract asr data from gcp transcripts logs
2. implement a funciton to export all call logs in a mongodb to a caller-id based yaml file 3. clean-up leaderboard duration logic 4. added a wip dataloader service 5. made the asr_data_writer util more generic with verbose flags and unique filename 6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node 7. refactored the validation post processing to dump a ui config for validation 8. included utility functions to correct, fill update and clear annotations from mongodb data 9. refactored the ui logic to be more generic for any asr data 10. updated setup.py dependencies to support the above features
This commit is contained in:
@@ -1,105 +1,137 @@
|
||||
import pymongo
|
||||
import typer
|
||||
|
||||
# import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# import pandas as pd
|
||||
from pydub import AudioSegment
|
||||
|
||||
# from .jasper_client import transcriber_pretrained, transcriber_speller
|
||||
from jasper.data_utils.validation.jasper_client import (
|
||||
transcriber_pretrained,
|
||||
transcriber_speller,
|
||||
)
|
||||
from jasper.data_utils.utils import alnum_to_asr_tokens
|
||||
|
||||
# import importlib
|
||||
# import jasper.data_utils.utils
|
||||
# importlib.reload(jasper.data_utils.utils)
|
||||
from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
|
||||
# from tqdm import tqdm as tqdm_base
|
||||
import typer
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import (
|
||||
alnum_to_asr_tokens,
|
||||
ExtendedPath,
|
||||
asr_manifest_reader,
|
||||
asr_manifest_writer,
|
||||
get_mongo_conn,
|
||||
)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from pydub import AudioSegment
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
from jasper.data_utils.validation.jasper_client import (
|
||||
transcriber_pretrained,
|
||||
transcriber_speller,
|
||||
)
|
||||
|
||||
try:
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel_root / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
res["spoken"] = alnum_to_asr_tokens(res["text"])
|
||||
res["utterance_id"] = audio_path.stem
|
||||
aud_seg = (
|
||||
AudioSegment.from_file_using_temporary_files(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
|
||||
if use_domain_asr:
|
||||
res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["domain_wer"] = word_error_rate(
|
||||
[res["spoken"]], [res["pretrained_asr"]]
|
||||
)
|
||||
wav_plot_path = (
|
||||
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
|
||||
)
|
||||
if not wav_plot_path.exists():
|
||||
fig = plt.Figure()
|
||||
ax = fig.add_subplot()
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
librosa.display.waveplot(y=y, sr=sr, ax=ax)
|
||||
with wav_plot_path.open("wb") as wav_plot_f:
|
||||
fig.set_tight_layout(True)
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
# fig.close()
|
||||
res["plot_path"] = str(wav_plot_path)
|
||||
return res
|
||||
except BaseException as e:
|
||||
print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_validation_ui_data(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
use_domain_asr: bool = True,
|
||||
):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
plot_dir = data_manifest_path.parent / Path("wav_plots")
|
||||
plot_dir.mkdir(parents=True, exist_ok=True)
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_funcs = [
|
||||
partial(
|
||||
preprocess_datapoint,
|
||||
i,
|
||||
data_manifest_path.parent,
|
||||
json.loads(v),
|
||||
use_domain_asr,
|
||||
)
|
||||
for i, v in enumerate(pnr_jsonl)
|
||||
]
|
||||
|
||||
def exec_func(f):
|
||||
return f()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=20) as exe:
|
||||
print("starting all plot tasks")
|
||||
pnr_data = filter(
|
||||
None,
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, pnr_funcs),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=len(pnr_funcs),
|
||||
)
|
||||
),
|
||||
)
|
||||
wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
|
||||
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
|
||||
ui_config = {"use_domain_asr": use_domain_asr, "data": result}
|
||||
ExtendedPath(dump_path).write_json(ui_config)
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
|
||||
col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
|
||||
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
|
||||
corrections = [c for c in cursor_obj]
|
||||
dump_f = dump_path.open("w")
|
||||
json.dump(corrections, dump_f, indent=2)
|
||||
dump_f.close()
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel, sample):
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = str(audio_path)
|
||||
res["gold_chars"] = audio_path.stem
|
||||
res["gold_phone"] = sample["text"]
|
||||
aud_seg = (
|
||||
AudioSegment.from_wav(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
|
||||
return res
|
||||
|
||||
|
||||
def load_dataset(data_manifest_path: Path):
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [
|
||||
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
|
||||
for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
|
||||
]
|
||||
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
|
||||
return result
|
||||
|
||||
|
||||
@app.command()
|
||||
def dump_processed_data(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
dump_path: Path = Path("./data/processed_data.json"),
|
||||
):
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [
|
||||
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
|
||||
for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
|
||||
]
|
||||
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
|
||||
dump_path = Path("./data/processed_data.json")
|
||||
dump_f = dump_path.open("w")
|
||||
json.dump(result, dump_f, indent=2)
|
||||
dump_f.close()
|
||||
ExtendedPath(dump_path).write_json(corrections)
|
||||
|
||||
|
||||
@app.command()
|
||||
def fill_unannotated(
|
||||
processed_data_path: Path = Path("./data/processed_data.json"),
|
||||
corrections_path: Path = Path("./data/corrections.json"),
|
||||
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
):
|
||||
processed_data = json.load(processed_data_path.open())
|
||||
corrections = json.load(corrections_path.open())
|
||||
annotated_codes = {c["code"] for c in corrections}
|
||||
all_codes = {c["gold_chars"] for c in processed_data}
|
||||
unann_codes = all_codes - annotated_codes
|
||||
mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
|
||||
mongo_conn = get_mongo_conn().test.asr_validation
|
||||
for c in unann_codes:
|
||||
mongo_conn.find_one_and_update(
|
||||
{"type": "correction", "code": c},
|
||||
@@ -111,8 +143,8 @@ def fill_unannotated(
|
||||
@app.command()
|
||||
def update_corrections(
|
||||
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
|
||||
processed_data_path: Path = Path("./data/processed_data.json"),
|
||||
corrections_path: Path = Path("./data/corrections.json"),
|
||||
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
|
||||
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
|
||||
):
|
||||
def correct_manifest(manifest_data_gen, corrections_path):
|
||||
corrections = json.load(corrections_path.open())
|
||||
@@ -168,6 +200,12 @@ def update_corrections(
|
||||
new_data_manifest_path.replace(data_manifest_path)
|
||||
|
||||
|
||||
@app.command()
|
||||
def clear_mongo_corrections():
|
||||
col = get_mongo_conn().test.asr_validation
|
||||
col.delete_many({"type": "correction"})
|
||||
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
||||
|
||||
@@ -1,27 +1,15 @@
|
||||
import json
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
import librosa
|
||||
import librosa.display
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
from pydub import AudioSegment
|
||||
import pymongo
|
||||
import typer
|
||||
from .jasper_client import transcriber_pretrained, transcriber_speller
|
||||
from ..utils import ExtendedPath, get_mongo_conn
|
||||
from .st_rerun import rerun
|
||||
|
||||
app = typer.Typer()
|
||||
st.title("ASR Speller Validation")
|
||||
|
||||
|
||||
if not hasattr(st, "mongo_connected"):
|
||||
st.mongoclient = pymongo.MongoClient(
|
||||
"mongodb://localhost:27017/"
|
||||
).test.asr_validation
|
||||
st.mongoclient = get_mongo_conn().test.asr_validation
|
||||
mongo_conn = st.mongoclient
|
||||
|
||||
def current_cursor_fn():
|
||||
@@ -63,80 +51,49 @@ if not hasattr(st, "mongo_connected"):
|
||||
st.mongo_connected = True
|
||||
|
||||
|
||||
# def clear_mongo_corrections():
|
||||
# col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
|
||||
# col.delete_many({"type": "correction"})
|
||||
|
||||
|
||||
def preprocess_datapoint(idx, rel, sample):
|
||||
res = dict(sample)
|
||||
res["real_idx"] = idx
|
||||
audio_path = rel / Path(sample["audio_filepath"])
|
||||
res["audio_path"] = audio_path
|
||||
res["gold_chars"] = audio_path.stem
|
||||
aud_seg = (
|
||||
AudioSegment.from_wav(audio_path)
|
||||
.set_channels(1)
|
||||
.set_sample_width(2)
|
||||
.set_frame_rate(24000)
|
||||
)
|
||||
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
|
||||
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
|
||||
res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]])
|
||||
(y, sr) = librosa.load(audio_path)
|
||||
plt.tight_layout()
|
||||
librosa.display.waveplot(y=y, sr=sr)
|
||||
wav_plot_f = BytesIO()
|
||||
plt.savefig(wav_plot_f, format="png", dpi=50)
|
||||
plt.close()
|
||||
wav_plot_f.seek(0)
|
||||
res["plot_png"] = wav_plot_f
|
||||
return res
|
||||
|
||||
|
||||
@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
|
||||
def preprocess_dataset(data_manifest_path: Path):
|
||||
typer.echo(f"Using data manifest:{data_manifest_path}")
|
||||
with data_manifest_path.open("r") as pf:
|
||||
pnr_jsonl = pf.readlines()
|
||||
pnr_data = [
|
||||
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
|
||||
for i, v in enumerate(tqdm(pnr_jsonl))
|
||||
]
|
||||
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
|
||||
return result
|
||||
@st.cache()
|
||||
def load_ui_data(validation_ui_data_path: Path):
|
||||
typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
|
||||
return ExtendedPath(validation_ui_data_path).read_json()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(manifest: Path):
|
||||
pnr_data = preprocess_dataset(manifest)
|
||||
ui_config = load_ui_data(manifest)
|
||||
asr_data = ui_config["data"]
|
||||
use_domain_asr = ui_config["use_domain_asr"]
|
||||
sample_no = st.get_current_cursor()
|
||||
sample = pnr_data[sample_no]
|
||||
st.markdown(
|
||||
f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*"
|
||||
sample = asr_data[sample_no]
|
||||
title_type = 'Speller ' if use_domain_asr else ''
|
||||
st.title(f"ASR {title_type}Validation")
|
||||
addl_text = (
|
||||
f"spelled *{sample['spoken']}*" if use_domain_asr else ""
|
||||
)
|
||||
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
|
||||
new_sample = st.number_input(
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data)
|
||||
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
|
||||
)
|
||||
if new_sample != sample_no + 1:
|
||||
st.update_cursor(new_sample - 1)
|
||||
st.sidebar.title(f"Details: [{sample['real_idx']}]")
|
||||
st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**")
|
||||
st.sidebar.markdown(f"Expected Speech: *{sample['text']}*")
|
||||
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
|
||||
st.sidebar.title("Results:")
|
||||
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
|
||||
st.sidebar.text(f"Speller:{sample['speller_asr']}")
|
||||
|
||||
st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%")
|
||||
# (y, sr) = librosa.load(sample["audio_path"])
|
||||
# librosa.display.waveplot(y=y, sr=sr)
|
||||
# st.sidebar.pyplot(fig=sample["plot_fig"])
|
||||
st.sidebar.image(sample["plot_png"])
|
||||
st.audio(sample["audio_path"].open("rb"))
|
||||
corrected = sample["gold_chars"]
|
||||
correction_entry = st.get_correction_entry(sample["gold_chars"])
|
||||
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
|
||||
if use_domain_asr:
|
||||
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
|
||||
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
|
||||
else:
|
||||
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
|
||||
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
|
||||
st.audio(Path(sample["audio_path"]).open("rb"))
|
||||
# set default to text
|
||||
corrected = sample["text"]
|
||||
correction_entry = st.get_correction_entry(sample["utterance_id"])
|
||||
selected_idx = 0
|
||||
options = ("Correct", "Incorrect", "Inaudible")
|
||||
# if correction entry is present set the corresponding ui defaults
|
||||
if correction_entry:
|
||||
selected_idx = options.index(correction_entry["value"]["status"])
|
||||
corrected = correction_entry["value"]["correction"]
|
||||
@@ -148,24 +105,26 @@ def main(manifest: Path):
|
||||
if st.button("Submit"):
|
||||
correct_code = corrected.replace(" ", "").upper()
|
||||
st.update_entry(
|
||||
sample["gold_chars"], {"status": selected, "correction": correct_code}
|
||||
sample["utterance_id"], {"status": selected, "correction": correct_code}
|
||||
)
|
||||
st.update_cursor(sample_no + 1)
|
||||
if correction_entry:
|
||||
st.markdown(
|
||||
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
|
||||
)
|
||||
# real_idx = st.text_input("Go to real-index:", value=sample['real_idx'])
|
||||
# st.markdown(
|
||||
# ",".join(
|
||||
# [
|
||||
# "**" + str(p["real_idx"]) + "**"
|
||||
# if p["real_idx"] == sample["real_idx"]
|
||||
# else str(p["real_idx"])
|
||||
# for p in pnr_data
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
# if st.button("Previous Untagged"):
|
||||
# pass
|
||||
# if st.button("Next Untagged"):
|
||||
# pass
|
||||
real_idx = st.number_input(
|
||||
"Go to real-index",
|
||||
value=sample["real_idx"],
|
||||
min_value=0,
|
||||
max_value=len(asr_data) - 1,
|
||||
)
|
||||
if real_idx != int(sample["real_idx"]):
|
||||
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
|
||||
st.update_cursor(idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user