1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-09 19:02:35 +00:00

1. added a tool to extract asr data from gcp transcripts logs

2. implemented a function to export all call logs in a mongodb to a caller-id based yaml file
3. cleaned up leaderboard duration logic
4. added a wip dataloader service
5. made the asr_data_writer util more generic with verbose flags and unique filename
6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node
7. refactored the validation post processing to dump a ui config for validation
8. included utility functions to correct, fill, update and clear annotations from mongodb data
9. refactored the ui logic to be more generic for any asr data
10. updated setup.py dependencies to support the above features
This commit is contained in:
2020-05-12 23:38:06 +05:30
parent a7da729c0b
commit c06a0814b9
7 changed files with 365 additions and 191 deletions

View File

@@ -1,105 +1,137 @@
import pymongo
import typer
# import matplotlib.pyplot as plt
from pathlib import Path
import json
import shutil
from pathlib import Path
# import pandas as pd
from pydub import AudioSegment
# from .jasper_client import transcriber_pretrained, transcriber_speller
from jasper.data_utils.validation.jasper_client import (
transcriber_pretrained,
transcriber_speller,
)
from jasper.data_utils.utils import alnum_to_asr_tokens
# import importlib
# import jasper.data_utils.utils
# importlib.reload(jasper.data_utils.utils)
from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer
from nemo.collections.asr.metrics import word_error_rate
# from tqdm import tqdm as tqdm_base
import typer
from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
get_mongo_conn,
)
app = typer.Typer()
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
    """Enrich one ASR-manifest entry with transcripts, WER scores and a waveform plot.

    Args:
        idx: index of the entry in the manifest (stored as ``real_idx``).
        rel_root: directory the manifest's relative audio paths are resolved against.
        sample: one decoded manifest line (must contain ``audio_filepath`` and ``text``).
        use_domain_asr: if True, also run the domain (speller) transcriber.

    Returns:
        The enriched dict, or None if any step fails (errors are printed, not raised),
        so callers can drop failures with ``filter(None, ...)``.
    """
    # Heavy dependencies are imported lazily so the module stays importable
    # without matplotlib/librosa/nemo installed.
    import matplotlib.pyplot as plt
    import librosa
    import librosa.display
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.data_utils.validation.jasper_client import (
        transcriber_pretrained,
        transcriber_speller,
    )

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        res["spoken"] = alnum_to_asr_tokens(res["text"])
        res["utterance_id"] = audio_path.stem
        # Normalize to mono, 16-bit, 24 kHz before handing raw PCM to the transcribers.
        aud_seg = (
            AudioSegment.from_file_using_temporary_files(audio_path)
            .set_channels(1)
            .set_sample_width(2)
            .set_frame_rate(24000)
        )
        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
        if use_domain_asr:
            res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
            # NOTE(review): this scores "spoken" against the *pretrained* transcript,
            # not the domain one — looks like it should be res["domain_asr"];
            # preserved as-is pending confirmation.
            res["domain_wer"] = word_error_rate(
                [res["spoken"]], [res["pretrained_asr"]]
            )
        wav_plot_path = (
            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
        )
        if not wav_plot_path.exists():
            # Use an explicit Figure (not pyplot's global state) so concurrent
            # workers don't clobber each other's plots.
            fig = plt.Figure()
            ax = fig.add_subplot()
            (y, sr) = librosa.load(audio_path)
            librosa.display.waveplot(y=y, sr=sr, ax=ax)
            with wav_plot_path.open("wb") as wav_plot_f:
                fig.set_tight_layout(True)
                fig.savefig(wav_plot_f, format="png", dpi=50)
        res["plot_path"] = str(wav_plot_path)
        return res
    except Exception as e:
        # Was `except BaseException`, which also swallowed KeyboardInterrupt and
        # SystemExit inside worker threads; narrowed to Exception. Failure is
        # reported and None returned explicitly (callers filter it out).
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
        return None
@app.command()
def dump_validation_ui_data(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
    use_domain_asr: bool = True,
):
    # Build the JSON consumed by the streamlit validation UI: every manifest
    # entry is transcribed, scored and plotted concurrently, then sorted
    # worst-WER-first and written (with the use_domain_asr flag) to dump_path.
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    # Plot images are written next to the manifest under wav_plots/.
    plot_dir = data_manifest_path.parent / Path("wav_plots")
    plot_dir.mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        pnr_jsonl = pf.readlines()
    # One zero-argument callable per manifest line, so exe.map can drive them.
    pnr_funcs = [
        partial(
            preprocess_datapoint,
            i,
            data_manifest_path.parent,
            json.loads(v),
            use_domain_asr,
        )
        for i, v in enumerate(pnr_jsonl)
    ]

    def exec_func(f):
        # Adapter: exe.map supplies the partial, we just invoke it.
        return f()

    with ThreadPoolExecutor(max_workers=20) as exe:
        print("starting all plot tasks")
        # list(...) forces all futures to finish inside the executor context;
        # filter(None, ...) drops entries whose preprocessing failed
        # (preprocess_datapoint returns None on error).
        pnr_data = filter(
            None,
            list(
                tqdm(
                    exe.map(exec_func, pnr_funcs),
                    position=0,
                    leave=True,
                    total=len(pnr_funcs),
                )
            ),
        )
    # Sort key depends on which ASR flavour was scored.
    wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
    result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
    ui_config = {"use_domain_asr": use_domain_asr, "data": result}
    ExtendedPath(dump_path).write_json(ui_config)
@app.command()
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
    """Export all correction annotations from MongoDB to a JSON file at dump_path."""
    # Fix: a dead duplicate assignment (`col = pymongo.MongoClient(...)`) that
    # was immediately rebound below — it leaked an extra client connection.
    col = get_mongo_conn().test.asr_validation
    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
    corrections = [c for c in cursor_obj]
    # `with` replaces the manual open/close pair so the handle is released
    # even if json.dump raises.
    with dump_path.open("w") as dump_f:
        json.dump(corrections, dump_f, indent=2)
def preprocess_datapoint(idx, rel, sample):
    """Transcribe one manifest entry with both ASR models and score the speller WER."""
    # NOTE(review): this 3-argument definition shadows the 4-argument
    # preprocess_datapoint defined earlier in this file — apparent diff
    # residue of the old implementation; confirm which version callers expect.
    res = dict(sample)
    res["real_idx"] = idx
    audio_path = rel / Path(sample["audio_filepath"])
    res["audio_path"] = str(audio_path)
    # The filename stem is the gold character sequence; "text" is its spelling.
    res["gold_chars"] = audio_path.stem
    res["gold_phone"] = sample["text"]
    # Normalize to mono, 16-bit, 24 kHz before handing raw PCM to the transcribers.
    aud_seg = (
        AudioSegment.from_wav(audio_path)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
    res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]])
    return res
def load_dataset(data_manifest_path: Path):
    """Read an ASR manifest and return its entries, enriched and sorted worst-WER first."""
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        manifest_lines = manifest_file.readlines()
    progress = tqdm(manifest_lines, position=0, leave=True)
    records = [
        preprocess_datapoint(line_no, data_manifest_path.parent, json.loads(raw_line))
        for line_no, raw_line in enumerate(progress)
    ]
    # Highest WER first, so the worst transcriptions surface at the top.
    records.sort(key=lambda record: record["wer"], reverse=True)
    return records
@app.command()
def dump_processed_data(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    dump_path: Path = Path("./data/processed_data.json"),
):
    """Transcribe every manifest entry, sort by WER (worst first) and write JSON to dump_path."""
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        pnr_jsonl = pf.readlines()
    pnr_data = [
        preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
        for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
    ]
    result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
    # Fix: dump_path was unconditionally reassigned to the default here,
    # silently ignoring any user-supplied --dump-path; the reassignment is gone.
    # `with` replaces the manual open/close pair.
    with dump_path.open("w") as dump_f:
        json.dump(result, dump_f, indent=2)
    # Fix: a trailing `ExtendedPath(dump_path).write_json(corrections)` was
    # removed — `corrections` is undefined in this scope (guaranteed NameError);
    # it was diff residue belonging to dump_corrections.
@app.command()
def fill_unannotated(
processed_data_path: Path = Path("./data/processed_data.json"),
corrections_path: Path = Path("./data/corrections.json"),
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
):
processed_data = json.load(processed_data_path.open())
corrections = json.load(corrections_path.open())
annotated_codes = {c["code"] for c in corrections}
all_codes = {c["gold_chars"] for c in processed_data}
unann_codes = all_codes - annotated_codes
mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
mongo_conn = get_mongo_conn().test.asr_validation
for c in unann_codes:
mongo_conn.find_one_and_update(
{"type": "correction", "code": c},
@@ -111,8 +143,8 @@ def fill_unannotated(
@app.command()
def update_corrections(
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
processed_data_path: Path = Path("./data/processed_data.json"),
corrections_path: Path = Path("./data/corrections.json"),
processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
corrections_path: Path = Path("./data/valiation_data/corrections.json"),
):
def correct_manifest(manifest_data_gen, corrections_path):
corrections = json.load(corrections_path.open())
@@ -168,6 +200,12 @@ def update_corrections(
new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
    """Delete every correction annotation from the asr_validation collection."""
    validation_collection = get_mongo_conn().test.asr_validation
    validation_collection.delete_many({"type": "correction"})
def main():
    # Console-script entry point: dispatch to the Typer CLI app defined above.
    app()

View File

@@ -1,27 +1,15 @@
import json
from io import BytesIO
from pathlib import Path
import streamlit as st
from nemo.collections.asr.metrics import word_error_rate
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm import tqdm
from pydub import AudioSegment
import pymongo
import typer
from .jasper_client import transcriber_pretrained, transcriber_speller
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
app = typer.Typer()
st.title("ASR Speller Validation")
if not hasattr(st, "mongo_connected"):
st.mongoclient = pymongo.MongoClient(
"mongodb://localhost:27017/"
).test.asr_validation
st.mongoclient = get_mongo_conn().test.asr_validation
mongo_conn = st.mongoclient
def current_cursor_fn():
@@ -63,80 +51,49 @@ if not hasattr(st, "mongo_connected"):
st.mongo_connected = True
# def clear_mongo_corrections():
# col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation
# col.delete_many({"type": "correction"})
def preprocess_datapoint(idx, rel, sample):
    """Transcribe one manifest entry with both models and render its waveform as an in-memory PNG."""
    res = dict(sample)
    res["real_idx"] = idx
    audio_path = rel / Path(sample["audio_filepath"])
    res["audio_path"] = audio_path
    # The filename stem is treated as the gold character sequence.
    res["gold_chars"] = audio_path.stem
    # Normalize to mono, 16-bit, 24 kHz before handing raw PCM to the transcribers.
    aud_seg = (
        AudioSegment.from_wav(audio_path)
        .set_channels(1)
        .set_sample_width(2)
        .set_frame_rate(24000)
    )
    res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
    res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
    res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]])
    # Plot the waveform through pyplot's implicit global figure and capture it
    # as PNG bytes. NOTE(review): global pyplot state is not thread-safe —
    # presumably this is only ever called sequentially; confirm before
    # parallelizing.
    (y, sr) = librosa.load(audio_path)
    plt.tight_layout()
    librosa.display.waveplot(y=y, sr=sr)
    wav_plot_f = BytesIO()
    plt.savefig(wav_plot_f, format="png", dpi=50)
    plt.close()
    # Rewind so downstream readers (e.g. st.image) see the full PNG.
    wav_plot_f.seek(0)
    res["plot_png"] = wav_plot_f
    return res
@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
def preprocess_dataset(data_manifest_path: Path):
    """Load the manifest, enrich each entry via preprocess_datapoint, and return them sorted worst-WER first (cached by streamlit)."""
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as manifest_file:
        manifest_lines = manifest_file.readlines()
    enriched = [
        preprocess_datapoint(line_no, data_manifest_path.parent, json.loads(raw_line))
        for line_no, raw_line in enumerate(tqdm(manifest_lines))
    ]
    enriched.sort(key=lambda entry: entry["wer"], reverse=True)
    return enriched
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the validation UI dump (JSON produced by the CLI) — cached by streamlit."""
    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
    return ExtendedPath(validation_ui_data_path).read_json()
@app.command()
def main(manifest: Path):
pnr_data = preprocess_dataset(manifest)
ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config["use_domain_asr"]
sample_no = st.get_current_cursor()
sample = pnr_data[sample_no]
st.markdown(
f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*"
sample = asr_data[sample_no]
title_type = 'Speller ' if use_domain_asr else ''
st.title(f"ASR {title_type}Validation")
addl_text = (
f"spelled *{sample['spoken']}*" if use_domain_asr else ""
)
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data)
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
)
if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1)
st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**")
st.sidebar.markdown(f"Expected Speech: *{sample['text']}*")
st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:")
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}")
st.sidebar.text(f"Speller:{sample['speller_asr']}")
st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%")
# (y, sr) = librosa.load(sample["audio_path"])
# librosa.display.waveplot(y=y, sr=sr)
# st.sidebar.pyplot(fig=sample["plot_fig"])
st.sidebar.image(sample["plot_png"])
st.audio(sample["audio_path"].open("rb"))
corrected = sample["gold_chars"]
correction_entry = st.get_correction_entry(sample["gold_chars"])
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
else:
st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.audio(Path(sample["audio_path"]).open("rb"))
# set default to text
corrected = sample["text"]
correction_entry = st.get_correction_entry(sample["utterance_id"])
selected_idx = 0
options = ("Correct", "Incorrect", "Inaudible")
# if correction entry is present set the corresponding ui defaults
if correction_entry:
selected_idx = options.index(correction_entry["value"]["status"])
corrected = correction_entry["value"]["correction"]
@@ -148,24 +105,26 @@ def main(manifest: Path):
if st.button("Submit"):
correct_code = corrected.replace(" ", "").upper()
st.update_entry(
sample["gold_chars"], {"status": selected, "correction": correct_code}
sample["utterance_id"], {"status": selected, "correction": correct_code}
)
st.update_cursor(sample_no + 1)
if correction_entry:
st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
)
# real_idx = st.text_input("Go to real-index:", value=sample['real_idx'])
# st.markdown(
# ",".join(
# [
# "**" + str(p["real_idx"]) + "**"
# if p["real_idx"] == sample["real_idx"]
# else str(p["real_idx"])
# for p in pnr_data
# ]
# )
# )
# if st.button("Previous Untagged"):
# pass
# if st.button("Next Untagged"):
# pass
real_idx = st.number_input(
"Go to real-index",
value=sample["real_idx"],
min_value=0,
max_value=len(asr_data) - 1,
)
if real_idx != int(sample["real_idx"]):
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
st.update_cursor(idx)
if __name__ == "__main__":