@app.command()
def extract_data(
    call_audio_dir: Path = Path("/dataset/png_prod/call_audio"),
    call_meta_dir: Path = Path("/dataset/png_prod/call_metadata"),
    output_dir: Path = Path("./data"),
    dataset_name: str = "png_gcp_2jan",
    verbose: bool = False,
):
    """Cut recorded calls into per-utterance ASR samples and write a dataset.

    Every ``*.wav`` found recursively under ``call_audio_dir`` is paired with
    the ``.json`` file at the same relative path under ``call_meta_dir``.
    The recording is split into its two mono channels, the events carrying an
    ``AsrResult`` are grouped per channel, and each event becomes one audio
    segment running from its start time to the start of the next event on the
    same channel (or to the end of the call for the last event).

    Args:
        call_audio_dir: Root directory searched for call recordings.
        call_meta_dir: Root directory holding the matching event JSON files.
        output_dir: Base directory; data goes under ``<output_dir>/asr_data``.
        dataset_name: Dataset name handed to ``asr_data_writer``.
        verbose: Echo per-file progress messages.
    """
    from pydub import AudioSegment
    from .utils import ExtendedPath, asr_data_writer
    from lenses import lens

    # Segments shorter than this many seconds are dropped.
    min_segment_secs = 0.5

    call_asr_data: Path = output_dir / Path("asr_data")
    call_asr_data.mkdir(exist_ok=True, parents=True)

    def wav_event_generator(call_audio_dir):
        """Yield (audio, wav path, events json) for every recording found."""
        for wav_path in call_audio_dir.glob("**/*.wav"):
            if verbose:
                typer.echo(f"loading events for file {wav_path}")
            call_wav = AudioSegment.from_file_using_temporary_files(wav_path)
            rel_meta_path = wav_path.with_suffix(".json").relative_to(call_audio_dir)
            meta_path = call_meta_dir / rel_meta_path
            events = ExtendedPath(meta_path).read_json()
            yield call_wav, wav_path, events

    def contains_asr(x):
        return "AsrResult" in x

    def channel(n):
        def filter_func(ev):
            # An event without an explicit "Channel" is treated as channel 0.
            return (
                ev["AsrResult"]["Channel"] == n
                if "Channel" in ev["AsrResult"]
                else n == 0
            )

        return filter_func

    def start_of(ev):
        # A missing "StartTime" is read as 0 for every event, so the look-ahead
        # at the next event cannot raise KeyError where the current one would not.
        return ev["AsrResult"]["Alternatives"][0].get("StartTime", 0)

    def compute_endtime(call_wav, state):
        """Yield (transcript, duration in seconds, wav bytes) per usable event."""
        for i, st in enumerate(state):
            start_time = start_of(st)
            transcript = st["AsrResult"]["Alternatives"][0]["Transcript"]
            if i + 1 < len(state):
                end_time = start_of(state[i + 1])
            else:
                end_time = call_wav.duration_seconds
            # Times are multiplied by 1000 before being used as slice bounds.
            code_seg = call_wav[start_time * 1000 : end_time * 1000]
            # only if enough audio data is present yield it; checking before
            # the export avoids encoding segments that are thrown away
            if code_seg.duration_seconds < min_segment_secs:
                continue
            code_fb = BytesIO()
            code_seg.export(code_fb, format="wav")
            yield transcript, code_seg.duration_seconds, code_fb.getvalue()

    def asr_data_generator(call_wav, call_wav_fname, events):
        """Chain the data points of both mono channels of one call."""
        # NOTE(review): unpacking into two names assumes a two-channel recording.
        call_wav_0, call_wav_1 = call_wav.split_to_mono()
        asr_events = lens["Events"].Each()["Event"].Filter(contains_asr)
        call_evs_0 = asr_events.Filter(channel(0)).collect()(events)
        call_evs_1 = asr_events.Filter(channel(1)).collect()(events)
        if verbose:
            typer.echo(f"processing data points on {call_wav_fname}")
        call_data_0 = compute_endtime(call_wav_0, call_evs_0)
        call_data_1 = compute_endtime(call_wav_1, call_evs_1)
        return chain(call_data_0, call_data_1)

    def generate_call_asr_data():
        """Load every call, then hand all data points to the dataset writer."""
        full_asr_data = []
        total_duration = 0
        for wav, wav_path, ev in wav_event_generator(call_audio_dir):
            asr_data = asr_data_generator(wav, wav_path, ev)
            total_duration += wav.duration_seconds
            full_asr_data.append(asr_data)
        typer.echo(f"loaded {len(full_asr_data)} calls of duration {total_duration}s")
        n_dps = asr_data_writer(
            call_asr_data, dataset_name, chain.from_iterable(full_asr_data)
        )
        typer.echo(f"written {n_dps} data points")

    generate_call_asr_data()
caller_list = [] + for caller in caller_calls: + caller_list.append({"name": caller, "calls": caller_calls[caller]}) + output_yaml = {"users": caller_list} + typer.echo("exporting call logs to yaml file") + with call_logs_file.open("w") as yf: + yaml.dump(output_yaml, yf) + + @app.command() def analyze( leaderboard: bool = False, @@ -19,8 +42,6 @@ def analyze( call_logs_file: Path = Path("./call_logs.yaml"), output_dir: Path = Path("./data"), ): - call_logs_file = Path("./call_logs.yaml") - output_dir = Path("./data") from urllib.parse import urlsplit from functools import reduce @@ -35,7 +56,6 @@ def analyze( from datetime import timedelta # from concurrent.futures import ThreadPoolExecutor - from dateutil.relativedelta import relativedelta import librosa import librosa.display from lenses import lens @@ -46,6 +66,8 @@ def analyze( from tqdm import tqdm from .utils import asr_data_writer from pydub import AudioSegment + from natural.date import compress + # from itertools import product, chain matplotlib.rcParams["agg.path.chunksize"] = 10000 @@ -256,8 +278,11 @@ def analyze( code_fb = BytesIO() code_seg.export(code_fb, format="wav") code_wav = code_fb.getvalue() - # import pdb; pdb.set_trace() - yield code, code_seg.duration_seconds, code_wav + # search for actual pnr code and handle plain codes as well + extracted_code = ( + re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code + ) + yield extracted_code, code_seg.duration_seconds, code_wav call_lens = lens["users"].Each()["calls"].Each() call_stats = call_lens.modify(retrieve_callmeta)(call_logs) @@ -275,22 +300,17 @@ def analyze( asr_data_writer(call_asr_data, "call_alphanum", data_source()) - # @leader_app.command() def show_leaderboard(): def compute_user_stats(call_stat): n_samples = ( lens["calls"].Each()["process"]["num_samples"].get_monoid()(call_stat) ) n_duration = lens["calls"].Each()["duration"].get_monoid()(call_stat) - rel_dur = relativedelta( - seconds=int(n_duration.total_seconds()), 
class ASRDataService(rpyc.Service):
    """RPyC service intended to serve ASR data loaders to remote clients."""

    def get_data_loader(self, data_manifest: Path):
        """Placeholder implementation.

        Currently ignores ``data_manifest`` and returns the string "hello".
        """
        return "hello"

    # RPyC's default protocol configuration only lets a remote peer reach
    # attributes whose name starts with "exposed_". The alias makes
    # ``conn.root.get_data_loader(...)`` callable from a client while keeping
    # the original method name available for in-process callers.
    exposed_get_data_loader = get_data_loader
def asr_manifest_reader(data_manifest_path: Path):
    """Yield the records of a JSON-lines ASR manifest, one dict per line.

    Each yielded dict is the parsed manifest line plus two derived keys:
    ``audio_path`` (the record's ``audio_filepath`` resolved against the
    manifest's directory) and ``chars`` (the stem of ``audio_filepath``).

    Blank lines are skipped, so a trailing empty line or a blank separator in
    the manifest no longer aborts the read with a JSON decode error.

    Args:
        data_manifest_path: Path of the manifest file to read.

    Yields:
        dict: One record per non-blank manifest line, in file order.
    """
    print(f"reading manifest from {data_manifest_path}")
    # The file is read completely and closed before the first record is
    # yielded, so no handle stays open while the caller consumes records.
    with data_manifest_path.open("r") as pf:
        pnr_jsonl = [line for line in pf if line.strip()]
    manifest_dir = data_manifest_path.parent
    for line in pnr_jsonl:
        p = json.loads(line)
        rel_audio = Path(p["audio_filepath"])
        p["audio_path"] = manifest_dir / rel_audio
        p["chars"] = rel_audio.stem
        yield p
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr):
    """Build the validation-UI record for one manifest sample.

    Transcribes the sample's audio with the pretrained model (and with the
    domain/speller model when ``use_domain_asr`` is set), scores each
    transcript with ``word_error_rate`` and makes sure a waveform PNG exists
    under ``<rel_root>/wav_plots``.

    Args:
        idx: Position of the sample in the manifest; stored as ``real_idx``.
        rel_root: Directory that ``sample["audio_filepath"]`` is relative to.
        sample: Parsed manifest record; ``audio_filepath`` and ``text`` are read.
        use_domain_asr: Also run the domain model and add ``domain_asr`` and
            ``domain_wer`` to the record.

    Returns:
        A copy of ``sample`` extended with the computed fields, or ``None``
        when processing the sample raised an exception (the failure is printed).
    """
    import matplotlib.pyplot as plt
    import librosa
    import librosa.display
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.data_utils.validation.jasper_client import (
        transcriber_pretrained,
        transcriber_speller,
    )

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        res["spoken"] = alnum_to_asr_tokens(res["text"])
        res["utterance_id"] = audio_path.stem
        # Both transcribers are fed the raw bytes of the audio converted to
        # 1 channel, sample width 2 and frame rate 24000.
        aud_seg = (
            AudioSegment.from_file_using_temporary_files(audio_path)
            .set_channels(1)
            .set_sample_width(2)
            .set_frame_rate(24000)
        )
        # NOTE(review): the gold text is passed as the first argument of
        # word_error_rate here and below - confirm this matches the
        # (hypotheses, references) order the installed NeMo version expects.
        res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
        res["pretrained_wer"] = word_error_rate([res["text"]], [res["pretrained_asr"]])
        if use_domain_asr:
            res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
            # Score the domain model against its own transcript; comparing
            # with the pretrained transcript made this a mislabeled
            # pretrained score.
            res["domain_wer"] = word_error_rate([res["spoken"]], [res["domain_asr"]])
        wav_plot_path = (
            rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png")
        )
        # An existing plot is reused, so re-running only renders missing ones.
        if not wav_plot_path.exists():
            # A standalone Figure is used instead of the pyplot current figure.
            fig = plt.Figure()
            ax = fig.add_subplot()
            (y, sr) = librosa.load(audio_path)
            librosa.display.waveplot(y=y, sr=sr, ax=ax)
            with wav_plot_path.open("wb") as wav_plot_f:
                fig.set_tight_layout(True)
                fig.savefig(wav_plot_f, format="png", dpi=50)
        res["plot_path"] = str(wav_plot_path)
        return res
    except Exception as e:
        # Best effort: one bad sample is reported and skipped (None is
        # returned). KeyboardInterrupt and SystemExit are no longer swallowed,
        # and a record without "audio_filepath" cannot raise again in here.
        print(f'failed on {idx}: {sample.get("audio_filepath")} with {e}')
        return None
@app.command()
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
    """Export all correction records from MongoDB to a JSON file.

    Reads every document with ``type == "correction"`` from the
    ``test.asr_validation`` collection (without the ``_id`` field) and writes
    the resulting list to ``dump_path``.

    NOTE(review): the default here is ``./data/corrections.json`` while
    ``fill_unannotated`` and ``update_corrections`` default to
    ``./data/valiation_data/corrections.json`` - confirm which is intended.

    Args:
        dump_path: Destination JSON file.
    """
    col = get_mongo_conn().test.asr_validation
    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
    ExtendedPath(dump_path).write_json(list(cursor_obj))
@app.command()
def clear_mongo_corrections():
    """Delete every ``type == "correction"`` document from ``test.asr_validation``."""
    correction_filter = {"type": "correction"}
    get_mongo_conn().test.asr_validation.delete_many(correction_filter)
@st.cache()
def load_ui_data(validation_ui_data_path: Path):
    """Read the validation UI dump (a JSON file) and return its parsed content.

    The result is cached through ``st.cache``.
    """
    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
    ui_dump = ExtendedPath(validation_ui_data_path)
    return ui_dump.read_json()
st.sidebar.markdown(f"Gold Text: **{sample['text']}**") + if use_domain_asr: + st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*") st.sidebar.title("Results:") - st.sidebar.text(f"Pretrained:{sample['pretrained_asr']}") - st.sidebar.text(f"Speller:{sample['speller_asr']}") - - st.sidebar.title(f"Speller WER: {sample['wer']:.2f}%") - # (y, sr) = librosa.load(sample["audio_path"]) - # librosa.display.waveplot(y=y, sr=sr) - # st.sidebar.pyplot(fig=sample["plot_fig"]) - st.sidebar.image(sample["plot_png"]) - st.audio(sample["audio_path"].open("rb")) - corrected = sample["gold_chars"] - correction_entry = st.get_correction_entry(sample["gold_chars"]) + st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**") + if use_domain_asr: + st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**") + st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%") + else: + st.sidebar.title(f"Pretrained WER: {sample['pretrained_wer']:.2f}%") + st.sidebar.image(Path(sample["plot_path"]).read_bytes()) + st.audio(Path(sample["audio_path"]).open("rb")) + # set default to text + corrected = sample["text"] + correction_entry = st.get_correction_entry(sample["utterance_id"]) selected_idx = 0 options = ("Correct", "Incorrect", "Inaudible") + # if correction entry is present set the corresponding ui defaults if correction_entry: selected_idx = options.index(correction_entry["value"]["status"]) corrected = correction_entry["value"]["correction"] @@ -148,24 +105,26 @@ def main(manifest: Path): if st.button("Submit"): correct_code = corrected.replace(" ", "").upper() st.update_entry( - sample["gold_chars"], {"status": selected, "correction": correct_code} + sample["utterance_id"], {"status": selected, "correction": correct_code} ) st.update_cursor(sample_no + 1) if correction_entry: st.markdown( f'Your Response: **{correction_entry["value"]["status"]}** Correction: **{correction_entry["value"]["correction"]}**' ) - # real_idx = st.text_input("Go to real-index:", 
value=sample['real_idx']) - # st.markdown( - # ",".join( - # [ - # "**" + str(p["real_idx"]) + "**" - # if p["real_idx"] == sample["real_idx"] - # else str(p["real_idx"]) - # for p in pnr_data - # ] - # ) - # ) + # if st.button("Previous Untagged"): + # pass + # if st.button("Next Untagged"): + # pass + real_idx = st.number_input( + "Go to real-index", + value=sample["real_idx"], + min_value=0, + max_value=len(asr_data) - 1, + ) + if real_idx != int(sample["real_idx"]): + idx = [i for (i, p) in enumerate(asr_data) if p["real_idx"] == real_idx][0] + st.update_cursor(idx) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 7111587..855d133 100644 --- a/setup.py +++ b/setup.py @@ -22,13 +22,19 @@ extra_requirements = { "matplotlib==3.2.1", "pandas==1.0.3", "tabulate==0.8.7", + "natural==0.2.0", + "num2words==0.5.10", "typer[all]==0.1.1", + "python-slugify==4.0.0", "lenses @ git+https://github.com/ingolemo/python-lenses.git@b2a2a9aa5b61540992d70b2cf36008d0121e8948#egg=lenses", ], "validation": [ "rpyc~=4.1.4", + "pymongo==3.10.1", + "typer[all]==0.1.1", "tqdm~=4.39.0", "librosa==0.7.2", + "matplotlib==3.2.1", "pydub~=0.23.1", "streamlit==0.58.0", "stringcase==1.2.0" @@ -58,6 +64,7 @@ setup( "jasper_asr_trainer = jasper.train:main", "jasper_asr_data_generate = jasper.data_utils.generator:main", "jasper_asr_data_recycle = jasper.data_utils.call_recycler:main", + "jasper_asr_data_validation = jasper.data_utils.validation.process:main", "jasper_asr_data_preprocess = jasper.data_utils.process:main", ] },