1. added a tool to extract asr data from gcp transcripts logs

2. implemented a function to export all call logs in a mongodb to a caller-id based yaml file
3. cleaned up leaderboard duration logic
4. added a wip dataloader service
5. made the asr_data_writer util more generic with verbose flags and unique filename
6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node
7. refactored the validation post processing to dump a ui config for validation
8. included utility functions to correct, fill, update and clear annotations from mongodb data
9. refactored the ui logic to be more generic for any asr data
10. updated setup.py dependencies to support the above features
Malar Kannan 2020-05-12 23:38:06 +05:30
parent a7da729c0b
commit c06a0814b9
7 changed files with 365 additions and 191 deletions

View File

@ -0,0 +1,93 @@
import typer
from itertools import chain
from io import BytesIO
from pathlib import Path
app = typer.Typer()
@app.command()
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    """Cut call recordings into per-utterance ASR samples and write a dataset.

    Every ``*.wav`` under ``call_audio_dir`` is paired with the JSON file at
    the same relative path under ``call_meta_dir``. Events carrying an
    ``AsrResult`` are grouped by channel, each utterance is sliced out of the
    matching mono channel, and the samples are handed to ``asr_data_writer``
    which stores them under ``output_dir / "asr_data" / dataset_name``.

    Args:
        call_audio_dir: Root directory searched recursively for wav files.
        call_meta_dir: Root directory holding the per-call JSON metadata.
        output_dir: Directory under which ``asr_data`` is created.
        dataset_name: Name of the dataset passed to ``asr_data_writer``.
        verbose: Echo per-file progress and forward the flag to the writer.
    """
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer
    from lenses import lens

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        """Yield (audio, wav path, metadata events) for every call recording."""
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        """Return True when the event carries an ``AsrResult`` entry."""
        return "AsrResult" in x

    def channel(n):
        """Build a predicate selecting ASR events that belong to channel ``n``."""

        def filter_func(ev):
            # Events without an explicit "Channel" are treated as channel 0.
            return (
                ev["AsrResult"]["Channel"] == n
                if "Channel" in ev["AsrResult"]
                else n == 0
            )

        return filter_func

    def compute_endtime(call_wav, state):
        """Yield (transcript, duration, wav bytes) for each usable utterance.

        An utterance ends where the next event on the same channel starts;
        the last one runs to the end of the recording.
        """
        for (i, st) in enumerate(state):
            start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0)
            transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
            if i + 1 < len(state):
                end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"]
            else:
                end_time = call_wav.duration_seconds
            # Times are in seconds; AudioSegment slicing works in milliseconds.
            code_seg = call_wav[start_time * 1000 : end_time * 1000]
            # Only yield when some audio data is present. Checking before the
            # export avoids encoding segments that would be thrown away.
            if code_seg.duration_seconds < 0.5:
                continue
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            yield transcript, code_seg.duration_seconds, code_fb.getvalue()

    def asr_data_generator(call_wav, call_wav_fname, events):
        """Return a lazy stream of samples from both channels of one call."""
        # Unpacking expects exactly two channels from split_to_mono().
        call_wav_0, call_wav_1 = call_wav.split_to_mono()
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
        call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        call_data_0 = compute_endtime(call_wav_0, call_evs_0)
        call_data_1 = compute_endtime(call_wav_1, call_evs_1)
        return chain(call_data_0, call_data_1)

    def generate_call_asr_data():
        """Collect the per-call sample streams and write them as one dataset."""
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(
            call_asr_data,
            dataset_name,
            chain.from_iterable(full_asr_data),
            verbose=verbose,
        )
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
def main():
    """Console entry point: hand control to the Typer application."""
    app()


if __name__ == "__main__":
    main()

View File

@ -11,6 +11,29 @@ app = typer.Typer()
# app.add_typer(plot_app, name="plot") # app.add_typer(plot_app, name="plot")
@app.command()
def export_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")):
    """Export the call URLs stored in MongoDB, grouped by caller, to YAML.

    Reads every document of the ``test.calls`` collection on the local
    MongoDB node and writes ``{"users": [{"name": ..., "calls": [...]}]}``
    to ``call_logs_file``.
    """
    from pymongo import MongoClient
    from collections import defaultdict
    from ruamel.yaml import YAML

    yaml = YAML()
    calls_collection = MongoClient("mongodb://localhost:27017/").test.calls

    # Group call URLs by the caller that made them.
    uris_by_caller = defaultdict(list)
    for record in calls_collection.find():
        system_id = record["SystemID"]
        uri = f"http://sia-data.agaralabs.com/calls/{system_id}"
        uris_by_caller[record["Caller"]].append(uri)

    users = [
        {"name": caller_id, "calls": uris}
        for caller_id, uris in uris_by_caller.items()
    ]
    typer.echo("exporting call logs to yaml file")
    with call_logs_file.open("w") as yf:
        yaml.dump({"users": users}, yf)
@app.command() @app.command()
def analyze( def analyze(
leaderboard: bool = False, leaderboard: bool = False,
@ -19,8 +42,6 @@ def analyze(
call_logs_file: Path = Path("./call_logs.yaml"), call_logs_file: Path = Path("./call_logs.yaml"),
output_dir: Path = Path("./data"), output_dir: Path = Path("./data"),
): ):
call_logs_file = Path("./call_logs.yaml")
output_dir = Path("./data")
from urllib.parse import urlsplit from urllib.parse import urlsplit
from functools import reduce from functools import reduce
@ -35,7 +56,6 @@ def analyze(
from datetime import timedelta from datetime import timedelta
# from concurrent.futures import ThreadPoolExecutor # from concurrent.futures import ThreadPoolExecutor
from dateutil.relativedelta import relativedelta
import librosa import librosa
import librosa.display import librosa.display
from lenses import lens from lenses import lens
@ -46,6 +66,8 @@ def analyze(
from tqdm import tqdm from tqdm import tqdm
from .utils import asr_data_writer from .utils import asr_data_writer
from pydub import AudioSegment from pydub import AudioSegment
from natural.date import compress
# from itertools import product, chain # from itertools import product, chain
matplotlib.rcParams["agg.path.chunksize"] = 10000 matplotlib.rcParams["agg.path.chunksize"] = 10000
@ -256,8 +278,11 @@ def analyze(
code_fb = BytesIO() code_fb = BytesIO()
code_seg.export(code_fb, format="wav") code_seg.export(code_fb, format="wav")
code_wav = code_fb.getvalue() code_wav = code_fb.getvalue()
# import pdb; pdb.set_trace() # search for actual pnr code and handle plain codes as well
yield code, code_seg.duration_seconds, code_wav extracted_code = (
re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code
)
yield extracted_code, code_seg.duration_seconds, code_wav
call_lens = lens["users"].Each()["calls"].Each() call_lens = lens["users"].Each()["calls"].Each()
call_stats = call_lens.modify(retrieve_callmeta)(call_logs) call_stats = call_lens.modify(retrieve_callmeta)(call_logs)
@ -275,22 +300,17 @@ def analyze(
asr_data_writer(call_asr_data, "call_alphanum", data_source()) asr_data_writer(call_asr_data, "call_alphanum", data_source())
# @leader_app.command()
def show_leaderboard(): def show_leaderboard():
def compute_user_stats(call_stat): def compute_user_stats(call_stat):
n_samples = ( n_samples = (
lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat) lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat)
) )
n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat) n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat)
rel_dur = relativedelta(
seconds=int(n_duration.total_seconds()),
microseconds=n_duration.microseconds,
)
return { return {
"num_samples": n_samples, "num_samples": n_samples,
"duration": n_duration.total_seconds(), "duration": n_duration.total_seconds(),
"samples_rate": n_samples / n_duration.total_seconds(), "samples_rate": n_samples / n_duration.total_seconds(),
"duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs", "duration_str": compress(n_duration, pad=" "),
"name": call_stat["name"], "name": call_stat["name"],
} }
@ -313,8 +333,8 @@ def analyze(
} }
)[["Rank", "Name", "Codes", "Duration"]] )[["Rank", "Name", "Codes", "Duration"]]
print( print(
"""Today's ASR Speller Dataset Leaderboard: """ASR Speller Dataset Leaderboard :
----------------------------------------""" ---------------------------------"""
) )
print(leader_board.to_string(index=False)) print(leader_board.to_string(index=False))

View File

@ -0,0 +1,29 @@
import typer
import rpyc
import os
from pathlib import Path
from rpyc.utils.server import ThreadedServer
app = typer.Typer()
class ASRDataService(rpyc.Service):
    """RPyC service intended to serve ASR data loaders (work in progress)."""

    def get_data_loader(self, data_manifest: Path):
        """Return a placeholder value; ``data_manifest`` is not used yet."""
        return "hello"

    # With RPyC's default protocol configuration only attributes carrying the
    # ``exposed_`` prefix are reachable from a client, so publish the method
    # under that name while keeping the original name for local callers.
    exposed_get_data_loader = get_data_loader
@app.command()
def run_server(port: int = 0):
    """Start a threaded RPyC server hosting ``ASRDataService``.

    A non-zero ``port`` is used as given; otherwise the port is read from
    the ``ASR_RPYC_PORT`` environment variable, defaulting to 8044.
    """
    if port:
        listen_port = port
    else:
        listen_port = int(os.environ.get("ASR_RPYC_PORT", "8044"))
    server = ThreadedServer(ASRDataService(), port=listen_port)
    typer.echo(f"starting asr server on {listen_port}...")
    server.start()
def main():
    """Console entry point: hand control to the Typer application."""
    app()


if __name__ == "__main__":
    main()

View File

@ -1,8 +1,13 @@
import numpy as np import numpy as np
import wave import wave
import io import io
import os
import json import json
from pathlib import Path from pathlib import Path
import pymongo
from slugify import slugify
from uuid import uuid4
from num2words import num2words from num2words import num2words
@ -46,42 +51,65 @@ def alnum_to_asr_tokens(text):
return ("".join(num_tokens)).lower() return ("".join(num_tokens)).lower()
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    """Write audio samples and a manifest for one dataset; return the count.

    ``asr_data_source`` yields ``(transcript, duration, wav_bytes)`` tuples.
    Each sample is stored under ``<output_dir>/<dataset_name>/wav`` with a
    unique, uuid-prefixed file name and recorded as one line of
    ``manifest.json`` in the dataset directory.
    """
    dataset_dir = output_dir / Path(dataset_name)
    wav_dir = dataset_dir / Path("wav")
    wav_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = dataset_dir / Path("manifest.json")
    num_datapoints = 0
    with manifest_path.open("w") as mf:
        for num_datapoints, sample in enumerate(asr_data_source, start=1):
            transcript, audio_dur, wav_data = sample
            # The uuid keeps names unique even when transcripts repeat.
            fname = f"{uuid4()}_{slugify(transcript, max_length=8)}"
            audio_file = wav_dir / Path(fname).with_suffix(".wav")
            audio_file.write_bytes(wav_data)
            rel_audio_path = audio_file.relative_to(dataset_dir)
            mf.write(manifest_str(str(rel_audio_path), audio_dur, transcript))
            if verbose:
                print(f"writing '{transcript}' of duration {audio_dur}")
    return num_datapoints
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the records of a JSON-lines manifest, one dict per line.

    Each record gains two keys: ``audio_path``, the audio file resolved
    against the manifest's directory, and ``chars``, the stem of the audio
    file name. Blank lines are skipped so that a stray empty line does not
    abort the read with a JSON decoding error.
    """
    print(f"reading manifest from {data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        pnr_jsonl = pf.readlines()
    manifest_dir = data_manifest_path.parent
    for line in pnr_jsonl:
        if not line.strip():
            continue
        p = json.loads(line)
        p["audio_path"] = manifest_dir / Path(p["audio_filepath"])
        p["chars"] = Path(p["audio_filepath"]).stem
        yield p
def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
    """Write one manifest line per dict yielded by ``manifest_str_source``.

    Every dict must provide the keys ``audio_filepath``, ``duration`` and
    ``text``; they are formatted with ``manifest_str``.
    """
    with asr_manifest_path.open("w") as mf:
        print(f"opening {asr_manifest_path} for writing manifest")
        for entry in manifest_str_source:
            line = manifest_str(
                entry["audio_filepath"], entry["duration"], entry["text"]
            )
            mf.write(line)
class ExtendedPath(type(Path())):
    """Path subclass of the platform's concrete path type with JSON helpers."""

    def read_json(self):
        """Parse the file at this path as JSON and return the decoded value."""
        return json.loads(self.read_text())

    def write_json(self, data):
        """Write ``data`` as 2-space-indented JSON, creating parent directories."""
        self.parent.mkdir(parents=True, exist_ok=True)
        with self.open("w") as handle:
            json.dump(data, handle, indent=2)
def get_mongo_conn(host=''):
    """Return a ``pymongo.MongoClient`` for port 27017 of the chosen host.

    An explicit ``host`` wins; otherwise the ``MONGO_HOST`` environment
    variable is used, falling back to ``localhost``.
    """
    mongo_host = host or os.environ.get("MONGO_HOST", "localhost")
    return pymongo.MongoClient(f"mongodb://{mongo_host}:27017/")
def main(): def main():
for c in random_pnr_generator(): for c in random_pnr_generator():
print(c) print(c)

View File

@ -1,105 +1,137 @@
import pymongo
import typer
# import matplotlib.pyplot as plt
from pathlib import Path
import json import json
import shutil import shutil
from pathlib import Path
# import pandas as pd import typer
from pydub import AudioSegment
# from .jasper_client import transcriber_pretrained, transcriber_speller
from jasper.data_utils.validation.jasper_client import (
transcriber_pretrained,
transcriber_speller,
)
from jasper.data_utils.utils import alnum_to_asr_tokens
# import importlib
# import jasper.data_utils.utils
# importlib.reload(jasper.data_utils.utils)
from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer
from nemo.collections.asr.metrics import word_error_rate
# from tqdm import tqdm as tqdm_base
from tqdm import tqdm from tqdm import tqdm
from ..utils import (
alnum_to_asr_tokens,
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
get_mongo_conn,
)
app = typer.Typer() app = typer.Typer()
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
    """Build the validation-UI record for one manifest sample.

    The audio is transcribed with the pretrained model (and the domain model
    when ``use_domain_asr`` is set), word error rates are computed, and a
    waveform plot is rendered next to the dataset if it does not exist yet.

    Args:
        idx: Position of the sample in the manifest, stored as ``real_idx``.
        rel_root: Directory that ``sample["audio_filepath"]`` is relative to.
        sample: Manifest record; must contain ``audio_filepath`` and ``text``.
        use_domain_asr: Also run the domain (speller) model and score it.

    Returns:
        The enriched record as a dict, or ``None`` when processing failed
        (the failure is printed so one bad sample does not stop the batch).
    """
    import matplotlib.pyplot as plt
    import librosa
    import librosa.display
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.data_utils.validation.jasper_client import (
        transcriber_pretrained,
        transcriber_speller,
    )

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        res["spoken"] = alnum_to_asr_tokens(res["text"])
        res["utterance_id"] = audio_path.stem
        aud_seg = (
            AudioSegment.from_file_using_temporary_files(audio_path)
            .set_channels(1)
            .set_sample_width(2)
            .set_frame_rate(24000)
        )
        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
        if use_domain_asr:
            res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
            # Score the domain model on its own output; this previously
            # compared against the pretrained transcript by mistake.
            res["domain_wer"] = word_error_rate(
                [res["spoken"]], [res["domain_asr"]]
            )
        wav_plot_path = (
            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
        )
        # Plots are reused across runs; only render the ones that are missing.
        if not wav_plot_path.exists():
            fig = plt.Figure()
            ax = fig.add_subplot()
            (y, sr) = librosa.load(audio_path)
            librosa.display.waveplot(y=y, sr=sr, ax=ax)
            with wav_plot_path.open("wb") as wav_plot_f:
                fig.set_tight_layout(True)
                fig.savefig(wav_plot_f, format="png", dpi=50)
        res["plot_path"] = str(wav_plot_path)
        return res
    except Exception as e:
        # Best effort: report the failing sample and let the caller drop it.
        # KeyboardInterrupt and SystemExit are deliberately not swallowed.
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
        return None
def load_dataset(data_manifest_path: Path):
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True))
]
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
return result
@app.command()
def dump_validation_ui_data(
    data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
    dump_path: Path = Path("./data/valiation_data/ui_dump.json"),
    use_domain_asr: bool = True,
):
    """Pre-compute the records shown by the validation UI and dump them as JSON.

    Every manifest line is processed by ``preprocess_datapoint`` on a thread
    pool. Samples that failed are dropped, the rest are sorted by word error
    rate (highest first) and written to ``dump_path`` together with the
    ``use_domain_asr`` flag.
    """
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    manifest_dir = data_manifest_path.parent
    (manifest_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True)
    typer.echo(f"Using data manifest:{data_manifest_path}")
    with data_manifest_path.open("r") as pf:
        manifest_lines = pf.readlines()
    tasks = [
        partial(
            preprocess_datapoint,
            line_no,
            manifest_dir,
            json.loads(line),
            use_domain_asr,
        )
        for line_no, line in enumerate(manifest_lines)
    ]

    def run_task(task):
        return task()

    with ThreadPoolExecutor(max_workers=20) as pool:
        print("starting all plot tasks")
        progress = tqdm(
            pool.map(run_task, tasks),
            position=0,
            leave=True,
            total=len(tasks),
        )
        # Failed samples come back falsy and are left out.
        processed = [record for record in progress if record]
    wer_key = "domain_wer" if use_domain_asr else "pretrained_wer"
    ranked = sorted(processed, key=lambda record: record[wer_key], reverse=True)
    ExtendedPath(dump_path).write_json(
        {"use_domain_asr": use_domain_asr, "data": ranked}
    )
@app.command()
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
    """Dump all "correction" documents from ``test.asr_validation`` to JSON.

    NOTE(review): fill_unannotated and update_corrections default to
    ./data/valiation_data/corrections.json for the same file; confirm
    whether this default is meant to differ.
    """
    validation_col = get_mongo_conn().test.asr_validation
    # Exclude Mongo's _id field from the dumped documents.
    found = validation_col.find({"type": "correction"}, projection={"_id": False})
    ExtendedPath(dump_path).write_json(list(found))
@app.command() @app.command()
def fill_unannotated( def fill_unannotated(
processed_data_path: Path = Path("./data/processed_data.json"), processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
corrections_path: Path = Path("./data/corrections.json"), corrections_path: Path = Path("./data/valiation_data/corrections.json"),
): ):
processed_data = json.load(processed_data_path.open()) processed_data = json.load(processed_data_path.open())
corrections = json.load(corrections_path.open()) corrections = json.load(corrections_path.open())
annotated_codes = {c["code"] for c in corrections} annotated_codes = {c["code"] for c in corrections}
all_codes = {c["gold_chars"] for c in processed_data} all_codes = {c["gold_chars"] for c in processed_data}
unann_codes = all_codes - annotated_codes unann_codes = all_codes - annotated_codes
mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation mongo_conn = get_mongo_conn().test.asr_validation
for c in unann_codes: for c in unann_codes:
mongo_conn.find_one_and_update( mongo_conn.find_one_and_update(
{"type": "correction", "code": c}, {"type": "correction", "code": c},
@ -111,8 +143,8 @@ def fill_unannotated(
@app.command() @app.command()
def update_corrections( def update_corrections(
data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"), data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
processed_data_path: Path = Path("./data/processed_data.json"), processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
corrections_path: Path = Path("./data/corrections.json"), corrections_path: Path = Path("./data/valiation_data/corrections.json"),
): ):
def correct_manifest(manifest_data_gen, corrections_path): def correct_manifest(manifest_data_gen, corrections_path):
corrections = json.load(corrections_path.open()) corrections = json.load(corrections_path.open())
@ -168,6 +200,12 @@ def update_corrections(
new_data_manifest_path.replace(data_manifest_path) new_data_manifest_path.replace(data_manifest_path)
@app.command()
def clear_mongo_corrections():
    """Delete every "correction" document from ``test.asr_validation``."""
    get_mongo_conn().test.asr_validation.delete_many({"type": "correction"})
def main(): def main():
app() app()

View File

@ -1,27 +1,15 @@
import json
from io import BytesIO
from pathlib import Path from pathlib import Path
import streamlit as st import streamlit as st
from nemo.collections.asr.metrics import word_error_rate
import librosa
import librosa.display
import matplotlib.pyplot as plt
from tqdm import tqdm
from pydub import AudioSegment
import pymongo
import typer import typer
from .jasper_client import transcriber_pretrained, transcriber_speller from ..utils import ExtendedPath, get_mongo_conn
from .st_rerun import rerun from .st_rerun import rerun
app = typer.Typer() app = typer.Typer()
st.title("ASR Speller Validation")
if not hasattr(st, "mongo_connected"): if not hasattr(st, "mongo_connected"):
st.mongoclient = pymongo.MongoClient( st.mongoclient = get_mongo_conn().test.asr_validation
"mongodb://localhost:27017/"
).test.asr_validation
mongo_conn = st.mongoclient mongo_conn = st.mongoclient
def current_cursor_fn(): def current_cursor_fn():
@ -63,80 +51,49 @@ if not hasattr(st, "mongo_connected"):
st.mongo_connected = True st.mongo_connected = True
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Load the validation UI dump (JSON) from ``validation_ui_data_path``."""
    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
    ui_dump = ExtendedPath(validation_ui_data_path)
    return ui_dump.read_json()
def preprocess_datapoint(idx, rel, sample):
res = dict(sample)
res["real_idx"] = idx
audio_path = rel / Path(sample["audio_filepath"])
res["audio_path"] = audio_path
res["gold_chars"] = audio_path.stem
aud_seg = (
AudioSegment.from_wav(audio_path)
.set_channels(1)
.set_sample_width(2)
.set_frame_rate(24000)
)
res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
res["speller_asr"] = transcriber_speller(aud_seg.raw_data)
res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]])
(y, sr) = librosa.load(audio_path)
plt.tight_layout()
librosa.display.waveplot(y=y, sr=sr)
wav_plot_f = BytesIO()
plt.savefig(wav_plot_f, format="png", dpi=50)
plt.close()
wav_plot_f.seek(0)
res["plot_png"] = wav_plot_f
return res
@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None})
def preprocess_dataset(data_manifest_path: Path):
typer.echo(f"Using data manifest:{data_manifest_path}")
with data_manifest_path.open("r") as pf:
pnr_jsonl = pf.readlines()
pnr_data = [
preprocess_datapoint(i, data_manifest_path.parent, json.loads(v))
for i, v in enumerate(tqdm(pnr_jsonl))
]
result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True)
return result
@app.command() @app.command()
def main(manifest: Path): def main(manifest: Path):
pnr_data = preprocess_dataset(manifest) ui_config = load_ui_data(manifest)
asr_data = ui_config["data"]
use_domain_asr = ui_config["use_domain_asr"]
sample_no = st.get_current_cursor() sample_no = st.get_current_cursor()
sample = pnr_data[sample_no] sample = asr_data[sample_no]
st.markdown( title_type = 'Speller ' if use_domain_asr else ''
f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*" st.title(f"ASR {title_type}Validation")
addl_text = (
f"spelled *{sample['spoken']}*" if use_domain_asr else ""
) )
st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text)
new_sample = st.number_input( new_sample = st.number_input(
"Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data) "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data)
) )
if new_sample != sample_no + 1: if new_sample != sample_no + 1:
st.update_cursor(new_sample - 1) st.update_cursor(new_sample - 1)
st.sidebar.title(f"Details: [{sample['real_idx']}]") st.sidebar.title(f"Details: [{sample['real_idx']}]")
st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**") st.sidebar.markdown(f"Gold Text: **{sample['text']}**")
st.sidebar.markdown(f"Expected Speech: *{sample['text']}*") if use_domain_asr:
st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*")
st.sidebar.title("Results:") st.sidebar.title("Results:")
st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}") st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**")
st.sidebar.text(f"Speller:{sample['speller_asr']}") if use_domain_asr:
st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**")
st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%") st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")
# (y, sr) = librosa.load(sample["audio_path"]) else:
# librosa.display.waveplot(y=y, sr=sr) st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%")
# st.sidebar.pyplot(fig=sample["plot_fig"]) st.sidebar.image(Path(sample["plot_path"]).read_bytes())
st.sidebar.image(sample["plot_png"]) st.audio(Path(sample["audio_path"]).open("rb"))
st.audio(sample["audio_path"].open("rb")) # set default to text
corrected = sample["gold_chars"] corrected = sample["text"]
correction_entry = st.get_correction_entry(sample["gold_chars"]) correction_entry = st.get_correction_entry(sample["utterance_id"])
selected_idx = 0 selected_idx = 0
options = ("Correct", "Incorrect", "Inaudible") options = ("Correct", "Incorrect", "Inaudible")
# if correction entry is present set the corresponding ui defaults
if correction_entry: if correction_entry:
selected_idx = options.index(correction_entry["value"]["status"]) selected_idx = options.index(correction_entry["value"]["status"])
corrected = correction_entry["value"]["correction"] corrected = correction_entry["value"]["correction"]
@ -148,24 +105,26 @@ def main(manifest: Path):
if st.button("Submit"): if st.button("Submit"):
correct_code = corrected.replace(" ", "").upper() correct_code = corrected.replace(" ", "").upper()
st.update_entry( st.update_entry(
sample["gold_chars"], {"status": selected, "correction": correct_code} sample["utterance_id"], {"status": selected, "correction": correct_code}
) )
st.update_cursor(sample_no + 1) st.update_cursor(sample_no + 1)
if correction_entry: if correction_entry:
st.markdown( st.markdown(
f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**'
) )
# real_idx = st.text_input("Go to real-index:", value=sample['real_idx']) # if st.button("Previous Untagged"):
# st.markdown( # pass
# ",".join( # if st.button("Next Untagged"):
# [ # pass
# "**" + str(p["real_idx"]) + "**" real_idx = st.number_input(
# if p["real_idx"] == sample["real_idx"] "Go to real-index",
# else str(p["real_idx"]) value=sample["real_idx"],
# for p in pnr_data min_value=0,
# ] max_value=len(asr_data) - 1,
# ) )
# ) if real_idx != int(sample["real_idx"]):
idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0]
st.update_cursor(idx)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -22,13 +22,19 @@ extra_requirements = {
"matplotlib==3.2.1", "matplotlib==3.2.1",
"pandas==1.0.3", "pandas==1.0.3",
"tabulate==0.8.7", "tabulate==0.8.7",
"natural==0.2.0",
"num2words==0.5.10",
"typer[all]==0.1.1", "typer[all]==0.1.1",
"python-slugify==4.0.0",
"lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses",
], ],
"validation": [ "validation": [
"rpyc~=4.1.4", "rpyc~=4.1.4",
"pymongo==3.10.1",
"typer[all]==0.1.1",
"tqdm~=4.39.0", "tqdm~=4.39.0",
"librosa==0.7.2", "librosa==0.7.2",
"matplotlib==3.2.1",
"pydub~=0.23.1", "pydub~=0.23.1",
"streamlit==0.58.0", "streamlit==0.58.0",
"stringcase==1.2.0" "stringcase==1.2.0"
@ -58,6 +64,7 @@ setup(
"jasper_asr_trainer = jasper.train:main", "jasper_asr_trainer = jasper.train:main",
"jasper_asr_data_generate = jasper.data_utils.generator:main", "jasper_asr_data_generate = jasper.data_utils.generator:main",
"jasper_asr_data_recycle = jasper.data_utils.call_recycler:main", "jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
"jasper_asr_data_validation = jasper.data_utils.validation.process:main",
"jasper_asr_data_preprocess = jasper.data_utils.process:main", "jasper_asr_data_preprocess = jasper.data_utils.process:main",
] ]
}, },