From c06a0814b97a5ce6cd5ab6eae857129082aa9a10 Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Tue, 12 May 2020 23:38:06 +0530 Subject: [PATCH] 1. added a tool to extract asr data from gcp transcripts logs 2. implement a funciton to export all call logs in a mongodb to a caller-id based yaml file 3. clean-up leaderboard duration logic 4. added a wip dataloader service 5. made the asr_data_writer util more generic with verbose flags and unique filename 6. added extendedpath util class with json support and mongo_conn function to connect to a mongo node 7. refactored the validation post processing to dump a ui config for validation 8. included utility functions to correct, fill update and clear annotations from mongodb data 9. refactored the ui logic to be more generic for any asr data 10. updated setup.py dependencies to support the above features --- jasper/data_utils/asr_recycler.py | 93 +++++++++++ jasper/data_utils/call_recycler.py | 46 ++++-- jasper/data_utils/data_server.py | 29 ++++ jasper/data_utils/utils.py | 50 ++++-- jasper/data_utils/validation/process.py | 200 ++++++++++++++---------- jasper/data_utils/validation/ui.py | 131 ++++++---------- setup.py | 7 + 7 files changed, 365 insertions(+), 191 deletions(-) create mode 100644 jasper/data_utils/asr_recycler.py create mode 100644 jasper/data_utils/data_server.py diff --git a/jasper/data_utils/asr_recycler.py b/jasper/data_utils/asr_recycler.py new file mode 100644 index 0000000..5f9cfc6 --- /dev/null +++ b/jasper/data_utils/asr_recycler.py @@ -0,0 +1,93 @@ +import typer +from itertools import chain +from io import BytesIO +from pathlib import Path + +app = typer.Typer() + + +@app.command() +def extract_data( + call_audio_dir: Path = Path("/dataset/png_prod/call_audio"), + call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"), + output_dir: Path = Path("./data"), + dataset_name: str = "png_gcp_2jan", + verbose: bool = False, +): + from pydub import AudioSegment + from .utils import ExtendedPath, asr_data_writer + from lenses import lens + + call_asr_data: Path = output_dir / Path("asr_data") + call_asr_data.mkdir(exist_ok=True, parents=True) + + def wav_event_generator(call_audio_dir): + for wav_path in call_audio_dir.glob("**/*.wav"): + if verbose: + typer.echo(f"loading events for file {wav_path}") + call_wav = AudioSegment.from_file_using_temporary_files(wav_path) + rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir) + meta_path = call_meta_dir / rel_meta_path + events = ExtendedPath(meta_path).read_json() + yield call_wav, wav_path, events + + def contains_asr(x): + return "AsrResult" in x + + def channel(n): + def filter_func(ev): + return ( + ev["AsrResult"]["Channel"] == n + if "Channel" in ev["AsrResult"] + else n == 0 + ) + + return filter_func + + def compute_endtime(call_wav, state): + for (i, st) in enumerate(state): + start_time = st["AsrResult"]["Alternatives"][0].get("StartTime", 0) + transcript = st["AsrResult"]["Alternatives"][0]["Transcript"] + if i + 1 < len(state): + end_time = state[i + 1]["AsrResult"]["Alternatives"][0]["StartTime"] + else: + end_time = call_wav.duration_seconds + code_seg = call_wav[start_time * 1000 : end_time * 1000] + code_fb = BytesIO() + code_seg.export(code_fb, format="wav") + code_wav = code_fb.getvalue() + # only of some audio data is present yield it + if code_seg.duration_seconds >= 0.5: + yield transcript, code_seg.duration_seconds, code_wav + + def asr_data_generator(call_wav, call_wav_fname, events): + call_wav_0, call_wav_1 = call_wav.split_to_mono() + asr_events = lens["Events"].Each()["Event"].Filter(contains_asr) + call_evs_0 = asr_events.Filter(channel(0)).collect()(events) + call_evs_1 = asr_events.Filter(channel(1)).collect()(events) + if verbose: + typer.echo(f"processing data points on {call_wav_fname}") + call_data_0 = compute_endtime(call_wav_0, call_evs_0) + call_data_1 = compute_endtime(call_wav_1, call_evs_1) + return chain(call_data_0, call_data_1) + + def generate_call_asr_data(): + full_asr_data = [] + total_duration = 0 + for wav, wav_path, ev in wav_event_generator(call_audio_dir): + asr_data = asr_data_generator(wav, wav_path, ev) + total_duration += wav.duration_seconds + full_asr_data.append(asr_data) + typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s") + n_dps = asr_data_writer(call_asr_data, dataset_name, chain(*full_asr_data)) + typer.echo(f"written {n_dps} data points") + + generate_call_asr_data() + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data_utils/call_recycler.py b/jasper/data_utils/call_recycler.py index d3904c0..ae966d5 100644 --- a/jasper/data_utils/call_recycler.py +++ b/jasper/data_utils/call_recycler.py @@ -11,6 +11,29 @@ app = typer.Typer() # app.add_typer(plot_app, name="plot") +@app.command() +def export_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")): + from pymongo import MongoClient + from collections import defaultdict + from ruamel.yaml import YAML + + yaml = YAML() + mongo_collection = MongoClient("mongodb://localhost:27017/").test.calls + caller_calls = defaultdict(lambda: []) + for call in mongo_collection.find(): + sysid = call["SystemID"] + call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}" + caller = call["Caller"] + caller_calls[caller].append(call_uri) + caller_list = [] + for caller in caller_calls: + caller_list.append({"name": caller, "calls": caller_calls[caller]}) + output_yaml = {"users": caller_list} + typer.echo("exporting call logs to yaml file") + with call_logs_file.open("w") as yf: + yaml.dump(output_yaml, yf) + + @app.command() def analyze( leaderboard: bool = False, @@ -19,8 +42,6 @@ def analyze( call_logs_file: Path = Path("./call_logs.yaml"), output_dir: Path = Path("./data"), ): - call_logs_file = Path("./call_logs.yaml") - output_dir = Path("./data") from urllib.parse import urlsplit from functools import reduce @@ -35,7 +56,6 @@ def analyze( from datetime import timedelta # from concurrent.futures import ThreadPoolExecutor - from dateutil.relativedelta import relativedelta import librosa import librosa.display from lenses import lens @@ -46,6 +66,8 @@ def analyze( from tqdm import tqdm from .utils import asr_data_writer from pydub import AudioSegment + from natural.date import compress + # from itertools import product, chain matplotlib.rcParams["agg.path.chunksize"] = 10000 @@ -256,8 +278,11 @@ def analyze( code_fb = BytesIO() code_seg.export(code_fb, format="wav") code_wav = code_fb.getvalue() - # import pdb; pdb.set_trace() - yield code, code_seg.duration_seconds, code_wav + # search for actual pnr code and handle plain codes as well + extracted_code = ( + re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code + ) + yield extracted_code, code_seg.duration_seconds, code_wav call_lens = lens["users"].Each()["calls"].Each() call_stats = call_lens.modify(retrieve_callmeta)(call_logs) @@ -275,22 +300,17 @@ def analyze( asr_data_writer(call_asr_data, "call_alphanum", data_source()) - # @leader_app.command() def show_leaderboard(): def compute_user_stats(call_stat): n_samples = ( lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat) ) n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat) - rel_dur = relativedelta( - seconds=int(n_duration.total_seconds()), - microseconds=n_duration.microseconds, - ) return { "num_samples": n_samples, "duration": n_duration.total_seconds(), "samples_rate": n_samples / n_duration.total_seconds(), - "duration_str": f"{rel_dur.minutes} mins {rel_dur.seconds} secs", + "duration_str": compress(n_duration, pad=" "), "name": call_stat["name"], } @@ -313,8 +333,8 @@ def analyze( } )[["Rank", "Name", "Codes", "Duration"]] print( - """Today's ASR Speller Dataset Leaderboard: -----------------------------------------""" + """ASR Speller Dataset Leaderboard : +---------------------------------""" ) print(leader_board.to_string(index=False)) diff --git a/jasper/data_utils/data_server.py b/jasper/data_utils/data_server.py new file mode 100644 index 0000000..3a912a8 --- /dev/null +++ b/jasper/data_utils/data_server.py @@ -0,0 +1,29 @@ +import typer +import rpyc +import os +from pathlib import Path +from rpyc.utils.server import ThreadedServer + +app = typer.Typer() + + +class ASRDataService(rpyc.Service): + def get_data_loader(self, data_manifest: Path): + return "hello" + + +@app.command() +def run_server(port: int = 0): + listen_port = port if port else int(os.environ.get("ASR_RPYC_PORT", "8044")) + service = ASRDataService() + t = ThreadedServer(service, port=listen_port) + typer.echo(f"starting asr server on {listen_port}...") + t.start() + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data_utils/utils.py b/jasper/data_utils/utils.py index 36893d0..da6ec5e 100644 --- a/jasper/data_utils/utils.py +++ b/jasper/data_utils/utils.py @@ -1,8 +1,13 @@ import numpy as np import wave import io +import os import json from pathlib import Path + +import pymongo +from slugify import slugify +from uuid import uuid4 from num2words import num2words @@ -46,42 +51,65 @@ def alnum_to_asr_tokens(text): return ("".join(num_tokens)).lower() -def asr_data_writer(output_dir, dataset_name, asr_data_source): +def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False): dataset_dir = output_dir / Path(dataset_name) (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True) asr_manifest = dataset_dir / Path("manifest.json") + num_datapoints = 0 with asr_manifest.open("w") as mf: - for pnr_code, audio_dur, wav_data in asr_data_source: - pnr_af = dataset_dir / Path("wav") / Path(pnr_code).with_suffix(".wav") + for transcript, audio_dur, wav_data in asr_data_source: + fname = str(uuid4()) + "_" + slugify(transcript, max_length=8) + pnr_af = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") pnr_af.write_bytes(wav_data) rel_pnr_path = pnr_af.relative_to(dataset_dir) - manifest = manifest_str( - str(rel_pnr_path), audio_dur, alnum_to_asr_tokens(pnr_code) - ) + manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript) mf.write(manifest) + if verbose: + print(f"writing '{transcript}' of duration {audio_dur}") + num_datapoints += 1 + return num_datapoints def asr_manifest_reader(data_manifest_path: Path): - print(f'reading manifest from {data_manifest_path}') + print(f"reading manifest from {data_manifest_path}") with data_manifest_path.open("r") as pf: pnr_jsonl = pf.readlines() pnr_data = [json.loads(v) for v in pnr_jsonl] for p in pnr_data: - p['audio_path'] = data_manifest_path.parent / Path(p['audio_filepath']) - p['chars'] = Path(p['audio_filepath']).stem + p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"]) + p["chars"] = Path(p["audio_filepath"]).stem yield p def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source): with asr_manifest_path.open("w") as mf: - print(f'opening {asr_manifest_path} for writing manifest') + print(f"opening {asr_manifest_path} for writing manifest") for mani_dict in manifest_str_source: manifest = manifest_str( - mani_dict['audio_filepath'], mani_dict['duration'], mani_dict['text'] + mani_dict["audio_filepath"], mani_dict["duration"], mani_dict["text"] ) mf.write(manifest) +class ExtendedPath(type(Path())): + """docstring for ExtendedPath.""" + + def read_json(self): + with self.open("r") as jf: + return json.load(jf) + + def write_json(self, data): + self.parent.mkdir(parents=True, exist_ok=True) + with self.open("w") as jf: + return json.dump(data, jf, indent=2) + + +def get_mongo_conn(host=''): + mongo_host = host if host else os.environ.get("MONGO_HOST", "localhost") + mongo_uri = f"mongodb://{mongo_host}:27017/" + return pymongo.MongoClient(mongo_uri) + + def main(): for c in random_pnr_generator(): print(c) diff --git a/jasper/data_utils/validation/process.py b/jasper/data_utils/validation/process.py index 2718771..a2588aa 100644 --- a/jasper/data_utils/validation/process.py +++ b/jasper/data_utils/validation/process.py @@ -1,105 +1,137 @@ -import pymongo -import typer - -# import matplotlib.pyplot as plt -from pathlib import Path import json import shutil +from pathlib import Path -# import pandas as pd -from pydub import AudioSegment - -# from .jasper_client import transcriber_pretrained, transcriber_speller -from jasper.data_utils.validation.jasper_client import ( - transcriber_pretrained, - transcriber_speller, -) -from jasper.data_utils.utils import alnum_to_asr_tokens - -# import importlib -# import jasper.data_utils.utils -# importlib.reload(jasper.data_utils.utils) -from jasper.data_utils.utils import asr_manifest_reader, asr_manifest_writer -from nemo.collections.asr.metrics import word_error_rate - -# from tqdm import tqdm as tqdm_base +import typer from tqdm import tqdm +from ..utils import ( + alnum_to_asr_tokens, + ExtendedPath, + asr_manifest_reader, + asr_manifest_writer, + get_mongo_conn, +) + app = typer.Typer() +def preprocess_datapoint(idx, rel_root, sample, use_domain_asr): + import matplotlib.pyplot as plt + import librosa + import librosa.display + from pydub import AudioSegment + from nemo.collections.asr.metrics import word_error_rate + from jasper.data_utils.validation.jasper_client import ( + transcriber_pretrained, + transcriber_speller, + ) + + try: + res = dict(sample) + res["real_idx"] = idx + audio_path = rel_root / Path(sample["audio_filepath"]) + res["audio_path"] = str(audio_path) + res["spoken"] = alnum_to_asr_tokens(res["text"]) + res["utterance_id"] = audio_path.stem + aud_seg = ( + AudioSegment.from_file_using_temporary_files(audio_path) + .set_channels(1) + .set_sample_width(2) + .set_frame_rate(24000) + ) + res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) + res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]]) + if use_domain_asr: + res["domain_asr"] = transcriber_speller(aud_seg.raw_data) + res["domain_wer"] = word_error_rate( + [res["spoken"]], [res["pretrained_asr"]] + ) + wav_plot_path = ( + rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png") + ) + if not wav_plot_path.exists(): + fig = plt.Figure() + ax = fig.add_subplot() + (y, sr) = librosa.load(audio_path) + librosa.display.waveplot(y=y, sr=sr, ax=ax) + with wav_plot_path.open("wb") as wav_plot_f: + fig.set_tight_layout(True) + fig.savefig(wav_plot_f, format="png", dpi=50) + # fig.close() + res["plot_path"] = str(wav_plot_path) + return res + except BaseException as e: + print(f'failed on {idx}: {sample["audio_filepath"]} with {e}') + + +@app.command() +def dump_validation_ui_data( + data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"), + dump_path: Path = Path("./data/valiation_data/ui_dump.json"), + use_domain_asr: bool = True, +): + from concurrent.futures import ThreadPoolExecutor + from functools import partial + + plot_dir = data_manifest_path.parent / Path("wav_plots") + plot_dir.mkdir(parents=True, exist_ok=True) + typer.echo(f"Using data manifest:{data_manifest_path}") + with data_manifest_path.open("r") as pf: + pnr_jsonl = pf.readlines() + pnr_funcs = [ + partial( + preprocess_datapoint, + i, + data_manifest_path.parent, + json.loads(v), + use_domain_asr, + ) + for i, v in enumerate(pnr_jsonl) + ] + + def exec_func(f): + return f() + + with ThreadPoolExecutor(max_workers=20) as exe: + print("starting all plot tasks") + pnr_data = filter( + None, + list( + tqdm( + exe.map(exec_func, pnr_funcs), + position=0, + leave=True, + total=len(pnr_funcs), + ) + ), + ) + wer_key = "domain_wer" if use_domain_asr else "pretrained_wer" + result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True) + ui_config = {"use_domain_asr": use_domain_asr, "data": result} + ExtendedPath(dump_path).write_json(ui_config) + + @app.command() def dump_corrections(dump_path: Path = Path("./data/corrections.json")): - col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation + col = get_mongo_conn().test.asr_validation cursor_obj = col.find({"type": "correction"}, projection={"_id": False}) corrections = [c for c in cursor_obj] - dump_f = dump_path.open("w") - json.dump(corrections, dump_f, indent=2) - dump_f.close() - - -def preprocess_datapoint(idx, rel, sample): - res = dict(sample) - res["real_idx"] = idx - audio_path = rel / Path(sample["audio_filepath"]) - res["audio_path"] = str(audio_path) - res["gold_chars"] = audio_path.stem - res["gold_phone"] = sample["text"] - aud_seg = ( - AudioSegment.from_wav(audio_path) - .set_channels(1) - .set_sample_width(2) - .set_frame_rate(24000) - ) - res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) - res["speller_asr"] = transcriber_speller(aud_seg.raw_data) - res["wer"] = word_error_rate([res["gold_phone"]], [res["speller_asr"]]) - return res - - -def load_dataset(data_manifest_path: Path): - typer.echo(f"Using data manifest:{data_manifest_path}") - with data_manifest_path.open("r") as pf: - pnr_jsonl = pf.readlines() - pnr_data = [ - preprocess_datapoint(i, data_manifest_path.parent, json.loads(v)) - for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True)) - ] - result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True) - return result - - -@app.command() -def dump_processed_data( - data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"), - dump_path: Path = Path("./data/processed_data.json"), -): - typer.echo(f"Using data manifest:{data_manifest_path}") - with data_manifest_path.open("r") as pf: - pnr_jsonl = pf.readlines() - pnr_data = [ - preprocess_datapoint(i, data_manifest_path.parent, json.loads(v)) - for i, v in enumerate(tqdm(pnr_jsonl, position=0, leave=True)) - ] - result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True) - dump_path = Path("./data/processed_data.json") - dump_f = dump_path.open("w") - json.dump(result, dump_f, indent=2) - dump_f.close() + ExtendedPath(dump_path).write_json(corrections) @app.command() def fill_unannotated( - processed_data_path: Path = Path("./data/processed_data.json"), - corrections_path: Path = Path("./data/corrections.json"), + processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"), + corrections_path: Path = Path("./data/valiation_data/corrections.json"), ): processed_data = json.load(processed_data_path.open()) corrections = json.load(corrections_path.open()) annotated_codes = {c["code"] for c in corrections} all_codes = {c["gold_chars"] for c in processed_data} unann_codes = all_codes - annotated_codes - mongo_conn = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation + mongo_conn = get_mongo_conn().test.asr_validation for c in unann_codes: mongo_conn.find_one_and_update( {"type": "correction", "code": c}, @@ -111,8 +143,8 @@ def fill_unannotated( @app.command() def update_corrections( data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"), - processed_data_path: Path = Path("./data/processed_data.json"), - corrections_path: Path = Path("./data/corrections.json"), + processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"), + corrections_path: Path = Path("./data/valiation_data/corrections.json"), ): def correct_manifest(manifest_data_gen, corrections_path): corrections = json.load(corrections_path.open()) @@ -168,6 +200,12 @@ def update_corrections( new_data_manifest_path.replace(data_manifest_path) +@app.command() +def clear_mongo_corrections(): + col = get_mongo_conn().test.asr_validation + col.delete_many({"type": "correction"}) + + def main(): app() diff --git a/jasper/data_utils/validation/ui.py b/jasper/data_utils/validation/ui.py index 44cf319..6d495cf 100644 --- a/jasper/data_utils/validation/ui.py +++ b/jasper/data_utils/validation/ui.py @@ -1,27 +1,15 @@ -import json -from io import BytesIO from pathlib import Path import streamlit as st -from nemo.collections.asr.metrics import word_error_rate -import librosa -import librosa.display -import matplotlib.pyplot as plt -from tqdm import tqdm -from pydub import AudioSegment -import pymongo import typer -from .jasper_client import transcriber_pretrained, transcriber_speller +from ..utils import ExtendedPath, get_mongo_conn from .st_rerun import rerun app = typer.Typer() -st.title("ASR Speller Validation") if not hasattr(st, "mongo_connected"): - st.mongoclient = pymongo.MongoClient( - "mongodb://localhost:27017/" - ).test.asr_validation + st.mongoclient = get_mongo_conn().test.asr_validation mongo_conn = st.mongoclient def current_cursor_fn(): @@ -63,80 +51,49 @@ if not hasattr(st, "mongo_connected"): st.mongo_connected = True -# def clear_mongo_corrections(): -# col = pymongo.MongoClient("mongodb://localhost:27017/").test.asr_validation -# col.delete_many({"type": "correction"}) - - -def preprocess_datapoint(idx, rel, sample): - res = dict(sample) - res["real_idx"] = idx - audio_path = rel / Path(sample["audio_filepath"]) - res["audio_path"] = audio_path - res["gold_chars"] = audio_path.stem - aud_seg = ( - AudioSegment.from_wav(audio_path) - .set_channels(1) - .set_sample_width(2) - .set_frame_rate(24000) - ) - res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data) - res["speller_asr"] = transcriber_speller(aud_seg.raw_data) - res["wer"] = word_error_rate([res["text"]], [res["speller_asr"]]) - (y, sr) = librosa.load(audio_path) - plt.tight_layout() - librosa.display.waveplot(y=y, sr=sr) - wav_plot_f = BytesIO() - plt.savefig(wav_plot_f, format="png", dpi=50) - plt.close() - wav_plot_f.seek(0) - res["plot_png"] = wav_plot_f - return res - - -@st.cache(hash_funcs={"rpyc.core.netref.builtins.method": lambda _: None}) -def preprocess_dataset(data_manifest_path: Path): - typer.echo(f"Using data manifest:{data_manifest_path}") - with data_manifest_path.open("r") as pf: - pnr_jsonl = pf.readlines() - pnr_data = [ - preprocess_datapoint(i, data_manifest_path.parent, json.loads(v)) - for i, v in enumerate(tqdm(pnr_jsonl)) - ] - result = sorted(pnr_data, key=lambda x: x["wer"], reverse=True) - return result +@st.cache() +def load_ui_data(validation_ui_data_path: Path): + typer.echo(f"Using validation ui data from :{validation_ui_data_path}") + return ExtendedPath(validation_ui_data_path).read_json() @app.command() def main(manifest: Path): - pnr_data = preprocess_dataset(manifest) + ui_config = load_ui_data(manifest) + asr_data = ui_config["data"] + use_domain_asr = ui_config["use_domain_asr"] sample_no = st.get_current_cursor() - sample = pnr_data[sample_no] - st.markdown( - f"{sample_no+1} of {len(pnr_data)} : **{sample['gold_chars']}** spelled *{sample['text']}*" + sample = asr_data[sample_no] + title_type = 'Speller ' if use_domain_asr else '' + st.title(f"ASR {title_type}Validation") + addl_text = ( + f"spelled *{sample['spoken']}*" if use_domain_asr else "" ) + st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text) new_sample = st.number_input( - "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(pnr_data) + "Go To Sample:", value=sample_no + 1, min_value=1, max_value=len(asr_data) ) if new_sample != sample_no + 1: st.update_cursor(new_sample - 1) st.sidebar.title(f"Details: [{sample['real_idx']}]") - st.sidebar.markdown(f"Gold: **{sample['gold_chars']}**") - st.sidebar.markdown(f"Expected Speech: *{sample['text']}*") + st.sidebar.markdown(f"Gold Text: **{sample['text']}**") + if use_domain_asr: + st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*") st.sidebar.title("Results:") - st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}") - st.sidebar.text(f"Speller:{sample['speller_asr']}") - - st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%") - # (y, sr) = librosa.load(sample["audio_path"]) - # librosa.display.waveplot(y=y, sr=sr) - # st.sidebar.pyplot(fig=sample["plot_fig"]) - st.sidebar.image(sample["plot_png"]) - st.audio(sample["audio_path"].open("rb")) - corrected = sample["gold_chars"] - correction_entry = st.get_correction_entry(sample["gold_chars"]) + st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**") + if use_domain_asr: + st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**") + st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%") + else: + st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%") + st.sidebar.image(Path(sample["plot_path"]).read_bytes()) + st.audio(Path(sample["audio_path"]).open("rb")) + # set default to text + corrected = sample["text"] + correction_entry = st.get_correction_entry(sample["utterance_id"]) selected_idx = 0 options = ("Correct", "Incorrect", "Inaudible") + # if correction entry is present set the corresponding ui defaults if correction_entry: selected_idx = options.index(correction_entry["value"]["status"]) corrected = correction_entry["value"]["correction"] @@ -148,24 +105,26 @@ def main(manifest: Path): if st.button("Submit"): correct_code = corrected.replace(" ", "").upper() st.update_entry( - sample["gold_chars"], {"status": selected, "correction": correct_code} + sample["utterance_id"], {"status": selected, "correction": correct_code} ) st.update_cursor(sample_no + 1) if correction_entry: st.markdown( f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' ) - # real_idx = st.text_input("Go to real-index:", value=sample['real_idx']) - # st.markdown( - # ",".join( - # [ - # "**" + str(p["real_idx"]) + "**" - # if p["real_idx"] == sample["real_idx"] - # else str(p["real_idx"]) - # for p in pnr_data - # ] - # ) - # ) + # if st.button("Previous Untagged"): + # pass + # if st.button("Next Untagged"): + # pass + real_idx = st.number_input( + "Go to real-index", + value=sample["real_idx"], + min_value=0, + max_value=len(asr_data) - 1, + ) + if real_idx != int(sample["real_idx"]): + idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0] + st.update_cursor(idx) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 7111587..855d133 100644 --- a/setup.py +++ b/setup.py @@ -22,13 +22,19 @@ extra_requirements = { "matplotlib==3.2.1", "pandas==1.0.3", "tabulate==0.8.7", + "natural==0.2.0", + "num2words==0.5.10", "typer[all]==0.1.1", + "python-slugify==4.0.0", "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", ], "validation": [ "rpyc~=4.1.4", + "pymongo==3.10.1", + "typer[all]==0.1.1", "tqdm~=4.39.0", "librosa==0.7.2", + "matplotlib==3.2.1", "pydub~=0.23.1", "streamlit==0.58.0", "stringcase==1.2.0" @@ -58,6 +64,7 @@ setup( "jasper_asr_trainer = jasper.train:main", "jasper_asr_data_generate = jasper.data_utils.generator:main", "jasper_asr_data_recycle = jasper.data_utils.call_recycler:main", + "jasper_asr_data_validation = jasper.data_utils.validation.process:main", "jasper_asr_data_preprocess = jasper.data_utils.process:main", ] },