1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-09 19:02:35 +00:00

1. added start delay arg in call recycler

2. implement ui_dump/manifest  writer in call_recycler itself
3. refactored call data point plotter
4. added sample-ui task-ui  on the validation process
5. implemented call-quality stats using corrections from mongo
6. support deleting cursors on mongo
7. implement multiple task support on validation ui based on task_id mongo field
This commit is contained in:
2020-06-17 19:11:15 +05:30
parent 7dbb04dcbf
commit 8e238c254e
5 changed files with 280 additions and 48 deletions

View File

@@ -12,6 +12,7 @@ from ..utils import (
asr_manifest_reader,
asr_manifest_writer,
get_mongo_conn,
plot_seg,
)
app = typer.Typer()
@@ -20,9 +21,6 @@ app = typer.Typer()
def preprocess_datapoint(
idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots
):
import matplotlib.pyplot as plt
import librosa
import librosa.display
from pydub import AudioSegment
from nemo.collections.asr.metrics import word_error_rate
from jasper.client import transcribe_gen
@@ -61,14 +59,7 @@ def preprocess_datapoint(
rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
)
if not wav_plot_path.exists():
fig = plt.Figure()
ax = fig.add_subplot()
(y, sr) = librosa.load(audio_path)
librosa.display.waveplot(y=y, sr=sr, ax=ax)
with wav_plot_path.open("wb") as wav_plot_f:
fig.set_tight_layout(True)
fig.savefig(wav_plot_f, format="png", dpi=50)
# fig.close()
plot_seg(wav_plot_path, audio_path)
res["plot_path"] = str(wav_plot_path)
return res
except BaseException as e:
@@ -131,17 +122,66 @@ def dump_ui(
result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True)
ui_config = {
"use_domain_asr": use_domain_asr,
"data": result,
"annotation_only": annotation_only,
"enable_plots": enable_plots,
"data": result,
}
ExtendedPath(dump_path).write_json(ui_config)
@app.command()
def sample_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
sample_count: int = typer.Option(80, show_default=True),
sample_file: Path = Path("sample_dump.json"),
):
import pandas as pd
processed_data_path = dump_dir / Path(data_name) / dump_file
sample_path = dump_dir / Path(data_name) / sample_file
processed_data = ExtendedPath(processed_data_path).read_json()
df = pd.DataFrame(processed_data["data"])
samples_per_caller = sample_count // len(df["caller"].unique())
caller_samples = pd.concat(
[g.sample(samples_per_caller) for (c, g) in df.groupby("caller")]
)
caller_samples = caller_samples.reset_index(drop=True)
caller_samples["real_idx"] = caller_samples.index
sample_data = caller_samples.to_dict("records")
processed_data["data"] = sample_data
typer.echo(f"sampling {sample_count} datapoints")
ExtendedPath(sample_path).write_json(processed_data)
@app.command()
def task_ui(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
task_count: int = typer.Option(4, show_default=True),
task_file: str = "task_dump",
):
import pandas as pd
import numpy as np
processed_data_path = dump_dir / Path(data_name) / dump_file
processed_data = ExtendedPath(processed_data_path).read_json()
df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True)
for t_idx, task_f in enumerate(np.array_split(df, task_count)):
task_f = task_f.reset_index(drop=True)
task_f["real_idx"] = task_f.index
task_data = task_f.to_dict("records")
processed_data["data"] = task_data
task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json")
ExtendedPath(task_path).write_json(processed_data)
@app.command()
def dump_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("corrections.json"),
):
dump_path = dump_dir / Path(data_name) / dump_fname
@@ -152,6 +192,38 @@ def dump_corrections(
ExtendedPath(dump_path).write_json(corrections)
@app.command()
def caller_quality(
data_name: str = typer.Option("call_upwork_train_cnd", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("ui_dump.json"),
correction_fname: Path = Path("corrections.json"),
):
import copy
import pandas as pd
dump_path = dump_dir / Path(data_name) / dump_fname
correction_path = dump_dir / Path(data_name) / correction_fname
dump_data = ExtendedPath(dump_path).read_json()
dump_map = {d["utterance_id"]: d for d in dump_data["data"]}
correction_data = ExtendedPath(correction_path).read_json()
def correction_dp(c):
dp = copy.deepcopy(dump_map[c["code"]])
dp["valid"] = c["value"]["status"] == "Correct"
return dp
corrected_dump = [correction_dp(c) for c in correction_data]
df = pd.DataFrame(corrected_dump)
print(f"Total samples: {len(df)}")
for (c, g) in df.groupby("caller"):
total = len(g)
valid = len(g[g["valid"] == True])
valid_rate = valid * 100 / total
print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples")
@app.command()
def fill_unannotated(
data_name: str = typer.Option("call_alphanum", show_default=True),
@@ -329,7 +401,9 @@ def clear_mongo_corrections():
if delete:
col = get_mongo_conn(col="asr_validation")
col.delete_many({"type": "correction"})
col.delete_many({"type": "current_cursor"})
typer.echo("deleted mongo collection.")
return
typer.echo("Aborted")

View File

@@ -2,6 +2,7 @@ from pathlib import Path
import streamlit as st
import typer
from uuid import uuid4
from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun
@@ -11,25 +12,25 @@ app = typer.Typer()
if not hasattr(st, "mongo_connected"):
st.mongoclient = get_mongo_conn(col="asr_validation")
mongo_conn = st.mongoclient
st.task_id = str(uuid4())
def current_cursor_fn():
# mongo_conn = st.mongoclient
cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
cursor_obj = mongo_conn.find_one(
{"type": "current_cursor", "task_id": st.task_id}
)
cursor_val = cursor_obj["cursor"]
return cursor_val
def update_cursor_fn(val=0):
mongo_conn.find_one_and_update(
{"type": "current_cursor"},
{"$set": {"type": "current_cursor", "cursor": val}},
{"type": "current_cursor", "task_id": st.task_id},
{"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}},
upsert=True,
)
rerun()
def get_correction_entry_fn(code):
# mongo_conn = st.mongoclient
# cursor_obj = mongo_conn.find_one({"type": "correction", "code": code})
# cursor_val = cursor_obj["cursor"]
return mongo_conn.find_one(
{"type": "correction", "code": code}, projection={"_id": False}
)
@@ -37,18 +38,25 @@ if not hasattr(st, "mongo_connected"):
def update_entry_fn(code, value):
mongo_conn.find_one_and_update(
{"type": "correction", "code": code},
{"$set": {"value": value}},
{"$set": {"value": value, "task_id": st.task_id}},
upsert=True,
)
cursor_obj = mongo_conn.find_one({"type": "current_cursor"})
if not cursor_obj:
update_cursor_fn(0)
def set_task_fn(mf_path):
task_path = mf_path.parent / Path(f"task-{st.task_id}.lck")
if not task_path.exists():
print(f"creating task lock at {task_path}")
task_path.touch()
st.get_current_cursor = current_cursor_fn
st.update_cursor = update_cursor_fn
st.get_correction_entry = get_correction_entry_fn
st.update_entry = update_entry_fn
st.set_task = set_task_fn
st.mongo_connected = True
cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id})
if not cursor_obj:
update_cursor_fn(0)
@st.cache()
@@ -59,6 +67,7 @@ def load_ui_data(validation_ui_data_path: Path):
@app.command()
def main(manifest: Path):
st.set_task(manifest)
ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config.get("use_domain_asr", True)
@@ -70,10 +79,11 @@ def main(manifest: Path):
st.update_cursor(0)
sample = asr_data[sample_no]
title_type = "Speller " if use_domain_asr else ""
task_uid = st.task_id.rsplit("-", 1)[1]
if annotation_only:
st.title(f"ASR Annotation")
st.title(f"ASR Annotation - # {task_uid}")
else:
st.title(f"ASR {title_type}Validation")
st.title(f"ASR {title_type}Validation - # {task_uid}")
addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else ""
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
new_sample = st.number_input(
@@ -88,6 +98,8 @@ def main(manifest: Path):
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:")
st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
if "caller" in sample:
st.sidebar.markdown(f"Caller: **{sample['caller']}**")
if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")