From 8e238c254ef6394e413aa49037433154c941307a Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Wed, 17 Jun 2020 19:11:15 +0530 Subject: [PATCH] 1. added start delay arg in call recycler 2. implement ui_dump/manifest writer in call_recycler itself 3. refactored call data point plotter 4. added sample-ui task-ui on the validation process 5. implemented call-quality stats using corrections from mongo 6. support deleting cursors on mongo 7. implement multiple task support on validation ui based on task_id mongo field --- .gitignore | 1 + jasper/data/call_recycler.py | 72 +++++++++++++----- jasper/data/utils.py | 119 ++++++++++++++++++++++++++++-- jasper/data/validation/process.py | 100 +++++++++++++++++++++---- jasper/data/validation/ui.py | 36 ++++++--- 5 files changed, 280 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index bda7618..3676125 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ /train/ .env* *.yaml +*.yml *.json diff --git a/jasper/data/call_recycler.py b/jasper/data/call_recycler.py index 4de5ef9..ad7eba8 100644 --- a/jasper/data/call_recycler.py +++ b/jasper/data/call_recycler.py @@ -93,8 +93,8 @@ def copy_metas( def copy_meta(uri): cid = get_cid(uri) - saved_meta_path = call_meta_dir / Path(f'{cid}.json') - dest_meta_path = meta_dir / Path(f'{cid}.json') + saved_meta_path = call_meta_dir / Path(f"{cid}.json") + dest_meta_path = meta_dir / Path(f"{cid}.json") if not saved_meta_path.exists(): print(f"{saved_meta_path} not found") copy2(saved_meta_path, dest_meta_path) @@ -106,7 +106,6 @@ def copy_metas( download_meta_audio() - class ExtractionType(str, Enum): flow = "flow" data = "data" @@ -120,6 +119,7 @@ def analyze( extraction_type: ExtractionType = typer.Option( ExtractionType.data, show_default=True ), + start_delay: float = 3, download_only: bool = False, call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True), output_dir: Path = Path("./data"), @@ -146,7 +146,7 @@ def analyze( import matplotlib.pyplot as plt import matplotlib from tqdm import tqdm - from .utils import asr_data_writer, get_mongo_coll + from .utils import ui_dump_manifest_writer, get_mongo_coll from pydub import AudioSegment from natural.date import compress @@ -215,7 +215,7 @@ def analyze( assert evs[0]["Type"] == "CONV_RESULT" assert evs[1]["Type"] == "STARTED_SPEAKING" assert evs[2]["Type"] == "STOPPED_SPEAKING" - start_time = td_fn(evs[1]).total_seconds() - 2 + start_time = td_fn(evs[1]).total_seconds() - start_delay end_time = td_fn(evs[2]).total_seconds() spoken = evs[0]["Msg"] data_points.append( @@ -227,7 +227,11 @@ def analyze( return data_points def text_extractor(spoken): - return re.search(r"'(.*)'", spoken).groups(0)[0] if len(spoken) > 6 and re.search(r"'(.*)'", spoken) else spoken + return ( + re.search(r"'(.*)'", spoken).groups(0)[0] + if len(spoken) > 6 and re.search(r"'(.*)'", spoken) + else spoken + ) elif extraction_type == ExtractionType.flow: @@ -254,14 +258,20 @@ def analyze( assert evs[1]["Type"] == "STARTED_SPEAKING" assert evs[2]["Type"] == "ASR_RESULT" assert evs[3]["Type"] == "STOPPED_SPEAKING" - start_time = td_fn(evs[1]).total_seconds() - 1.5 + start_time = td_fn(evs[1]).total_seconds() - start_delay end_time = td_fn(evs[2]).total_seconds() conv_msg = evs[0]["Msg"] - if 'full name' in conv_msg.lower(): + if "full name" in conv_msg.lower(): pld = json.loads(evs[2]["Payload"]) - spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0]['Transcript'] + spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0][ + "Transcript" + ] data_points.append( - {"start_time": start_time, "end_time": end_time, "code": spoken} + { + "start_time": start_time, + "end_time": end_time, + "code": spoken, + } ) except AssertionError: # skipping invalid data_points @@ -330,6 +340,25 @@ def analyze( process_meta["data_points"] = data_points return {"url": uri, "meta": meta, "duration": duration, "process": process_meta} + def retrieve_callmeta(call_uri): + uri = call_uri["call_uri"] + name = call_uri["name"] + cid = get_cid(uri) + meta = mongo_collection.find_one({"SystemID": cid}) + duration = meta["EndTS"] - meta["StartTS"] + process_meta = process_call(meta) + data_points = get_data_points( + process_meta["utter_events"], process_meta["first_event_fn"] + ) + process_meta["data_points"] = data_points + return { + "url": uri, + "name": name, + "meta": meta, + "duration": duration, + "process": process_meta, + } + def download_meta_audio(): call_lens = lens["users"].Each()["calls"].Each() call_lens.modify(ensure_call)(call_logs) @@ -379,7 +408,7 @@ def analyze( pprint(call_plots) def extract_data_points(): - def gen_data_values(saved_wav_path, data_points): + def gen_data_values(saved_wav_path, data_points, caller_name): call_seg = ( AudioSegment.from_wav(saved_wav_path) .set_channels(1) @@ -394,23 +423,32 @@ def analyze( spoken_wav = spoken_fb.getvalue() # search for actual pnr code and handle plain codes as well extracted_code = text_extractor(spoken) - yield extracted_code, spoken_seg.duration_seconds, spoken_wav + yield extracted_code, spoken_seg.duration_seconds, spoken_wav, caller_name, spoken_seg call_lens = lens["users"].Each()["calls"].Each() - call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs) + + def assign_user_call(uc): + return ( + lens["calls"] + .Each() + .modify(lambda c: {"call_uri": c, "name": uc["name"]})(uc) + ) + + user_call_logs = lens["users"].Each().modify(assign_user_call)(call_logs) + call_stats = call_lens.modify(retrieve_callmeta)(user_call_logs) call_objs = call_lens.collect()(call_stats) def data_source(): for call_obj in tqdm(call_objs): - saved_wav_path, data_points, sys_id = ( + saved_wav_path, data_points, name = ( call_obj["process"]["wav_path"], call_obj["process"]["data_points"], - call_obj["meta"]["SystemID"], + call_obj["name"], ) - for dp in gen_data_values(saved_wav_path, data_points): + for dp in gen_data_values(saved_wav_path, data_points, name): yield dp - asr_data_writer(call_asr_data, dataset_name, data_source()) + ui_dump_manifest_writer(call_asr_data, dataset_name, data_source()) def show_leaderboard(): def compute_user_stats(call_stat): diff --git a/jasper/data/utils.py b/jasper/data/utils.py index eda1a65..d4d1c7f 100644 --- a/jasper/data/utils.py +++ b/jasper/data/utils.py @@ -9,6 +9,14 @@ import pymongo from slugify import slugify from uuid import uuid4 from num2words import num2words +from jasper.client import transcribe_gen +from nemo.collections.asr.metrics import word_error_rate +import matplotlib.pyplot as plt +import librosa +import librosa.display +from tqdm import tqdm +from functools import partial +from concurrent.futures import ThreadPoolExecutor def manifest_str(path, dur, text): @@ -57,11 +65,12 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False): asr_manifest = dataset_dir / Path("manifest.json") num_datapoints = 0 with asr_manifest.open("w") as mf: + print(f"writing manifest to {asr_manifest}") for transcript, audio_dur, wav_data in asr_data_source: fname = str(uuid4()) + "_" + slugify(transcript, max_length=8) - pnr_af = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") - pnr_af.write_bytes(wav_data) - rel_pnr_path = pnr_af.relative_to(dataset_dir) + audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") + audio_file.write_bytes(wav_data) + rel_pnr_path = audio_file.relative_to(dataset_dir) manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript) mf.write(manifest) if verbose: @@ -70,6 +79,94 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False): return num_datapoints +def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=False): + dataset_dir = output_dir / Path(dataset_name) + (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True) + ui_dump_file = dataset_dir / Path("ui_dump.json") + (dataset_dir / Path("wav_plots")).mkdir(parents=True, exist_ok=True) + asr_manifest = dataset_dir / Path("manifest.json") + num_datapoints = 0 + ui_dump = { + "use_domain_asr": False, + "annotation_only": False, + "enable_plots": True, + "data": [], + } + data_funcs = [] + transcriber_pretrained = transcribe_gen(asr_port=8044) + with asr_manifest.open("w") as mf: + print(f"writing manifest to {asr_manifest}") + + def data_fn( + transcript, + audio_dur, + wav_data, + caller_name, + aud_seg, + fname, + audio_path, + num_datapoints, + rel_pnr_path, + ): + pretrained_result = transcriber_pretrained(aud_seg.raw_data) + pretrained_wer = word_error_rate([transcript], [pretrained_result]) + wav_plot_path = ( + dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png") + ) + if not wav_plot_path.exists(): + plot_seg(wav_plot_path, audio_path) + return { + "audio_filepath": str(rel_pnr_path), + "duration": round(audio_dur, 1), + "text": transcript, + "real_idx": num_datapoints, + "audio_path": audio_path, + "spoken": transcript, + "caller": caller_name, + "utterance_id": fname, + "pretrained_asr": pretrained_result, + "pretrained_wer": pretrained_wer, + "plot_path": str(wav_plot_path), + } + + for transcript, audio_dur, wav_data, caller_name, aud_seg in asr_data_source: + fname = str(uuid4()) + "_" + slugify(transcript, max_length=8) + audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav") + audio_file.write_bytes(wav_data) + audio_path = str(audio_file) + rel_pnr_path = audio_file.relative_to(dataset_dir) + manifest = manifest_str(str(rel_pnr_path), audio_dur, transcript) + mf.write(manifest) + data_funcs.append( + partial( + data_fn, + transcript, + audio_dur, + wav_data, + caller_name, + aud_seg, + fname, + audio_path, + num_datapoints, + rel_pnr_path, + ) + ) + num_datapoints += 1 + with ThreadPoolExecutor() as exe: + print("starting all plot/transcription tasks") + dump_data = list( + tqdm( + exe.map(lambda x: x(), data_funcs), + position=0, + leave=True, + total=len(data_funcs), + ) + ) + ui_dump["data"] = dump_data + ExtendedPath(ui_dump_file).write_json(ui_dump) + return num_datapoints + + def asr_manifest_reader(data_manifest_path: Path): print(f"reading manifest from {data_manifest_path}") with data_manifest_path.open("r") as pf: @@ -95,12 +192,12 @@ class ExtendedPath(type(Path())): """docstring for ExtendedPath.""" def read_json(self): - print(f'reading json from {self}') + print(f"reading json from {self}") with self.open("r") as jf: return json.load(jf) def write_json(self, data): - print(f'writing json to {self}') + print(f"writing json to {self}") self.parent.mkdir(parents=True, exist_ok=True) with self.open("w") as jf: return json.dump(data, jf, indent=2) @@ -109,7 +206,7 @@ class ExtendedPath(type(Path())): def get_mongo_coll(uri="mongodb://localhost:27017/test.calls"): ud = pymongo.uri_parser.parse_uri(uri) conn = pymongo.MongoClient(uri) - return conn[ud['database']][ud['collection']] + return conn[ud["database"]][ud["collection"]] def get_mongo_conn(host="", port=27017, db="test", col="calls"): @@ -127,6 +224,16 @@ def strip_silence(sound): return sound[start_trim : duration - end_trim] +def plot_seg(wav_plot_path, audio_path): + fig = plt.Figure() + ax = fig.add_subplot() + (y, sr) = librosa.load(audio_path) + librosa.display.waveplot(y=y, sr=sr, ax=ax) + with wav_plot_path.open("wb") as wav_plot_f: + fig.set_tight_layout(True) + fig.savefig(wav_plot_f, format="png", dpi=50) + + def main(): for c in random_pnr_generator(): print(c) diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py index 44133ef..7013577 100644 --- a/jasper/data/validation/process.py +++ b/jasper/data/validation/process.py @@ -12,6 +12,7 @@ from ..utils import ( asr_manifest_reader, asr_manifest_writer, get_mongo_conn, + plot_seg, ) app = typer.Typer() @@ -20,9 +21,6 @@ app = typer.Typer() def preprocess_datapoint( idx, rel_root, sample, use_domain_asr, annotation_only, enable_plots ): - import matplotlib.pyplot as plt - import librosa - import librosa.display from pydub import AudioSegment from nemo.collections.asr.metrics import word_error_rate from jasper.client import transcribe_gen @@ -61,14 +59,7 @@ def preprocess_datapoint( rel_root / Path("wav_plots") / Path(audio_path.name).with_suffix(".png") ) if not wav_plot_path.exists(): - fig = plt.Figure() - ax = fig.add_subplot() - (y, sr) = librosa.load(audio_path) - librosa.display.waveplot(y=y, sr=sr, ax=ax) - with wav_plot_path.open("wb") as wav_plot_f: - fig.set_tight_layout(True) - fig.savefig(wav_plot_f, format="png", dpi=50) - # fig.close() + plot_seg(wav_plot_path, audio_path) res["plot_path"] = str(wav_plot_path) return res except BaseException as e: @@ -131,17 +122,66 @@ def dump_ui( result = sorted(pnr_data, key=lambda x: x[wer_key], reverse=True) ui_config = { "use_domain_asr": use_domain_asr, - "data": result, "annotation_only": annotation_only, "enable_plots": enable_plots, + "data": result, } ExtendedPath(dump_path).write_json(ui_config) +@app.command() +def sample_ui( + data_name: str = typer.Option("call_upwork_train_cnd", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + dump_file: Path = Path("ui_dump.json"), + sample_count: int = typer.Option(80, show_default=True), + sample_file: Path = Path("sample_dump.json"), +): + import pandas as pd + + processed_data_path = dump_dir / Path(data_name) / dump_file + sample_path = dump_dir / Path(data_name) / sample_file + processed_data = ExtendedPath(processed_data_path).read_json() + df = pd.DataFrame(processed_data["data"]) + samples_per_caller = sample_count // len(df["caller"].unique()) + caller_samples = pd.concat( + [g.sample(samples_per_caller) for (c, g) in df.groupby("caller")] + ) + caller_samples = caller_samples.reset_index(drop=True) + caller_samples["real_idx"] = caller_samples.index + sample_data = caller_samples.to_dict("records") + processed_data["data"] = sample_data + typer.echo(f"sampling {sample_count} datapoints") + ExtendedPath(sample_path).write_json(processed_data) + + +@app.command() +def task_ui( + data_name: str = typer.Option("call_upwork_train_cnd", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + dump_file: Path = Path("ui_dump.json"), + task_count: int = typer.Option(4, show_default=True), + task_file: str = "task_dump", +): + import pandas as pd + import numpy as np + + processed_data_path = dump_dir / Path(data_name) / dump_file + processed_data = ExtendedPath(processed_data_path).read_json() + df = pd.DataFrame(processed_data["data"]).sample(frac=1).reset_index(drop=True) + for t_idx, task_f in enumerate(np.array_split(df, task_count)): + task_f = task_f.reset_index(drop=True) + task_f["real_idx"] = task_f.index + task_data = task_f.to_dict("records") + processed_data["data"] = task_data + task_path = dump_dir / Path(data_name) / Path(task_file + f"-{t_idx}.json") + ExtendedPath(task_path).write_json(processed_data) + + @app.command() def dump_corrections( data_name: str = typer.Option("call_alphanum", show_default=True), - dump_dir: Path = Path("./data/valiation_data"), + dump_dir: Path = Path("./data/asr_data"), dump_fname: Path = Path("corrections.json"), ): dump_path = dump_dir / Path(data_name) / dump_fname @@ -152,6 +192,38 @@ def dump_corrections( ExtendedPath(dump_path).write_json(corrections) +@app.command() +def caller_quality( + data_name: str = typer.Option("call_upwork_train_cnd", show_default=True), + dump_dir: Path = Path("./data/asr_data"), + dump_fname: Path = Path("ui_dump.json"), + correction_fname: Path = Path("corrections.json"), +): + import copy + import pandas as pd + + dump_path = dump_dir / Path(data_name) / dump_fname + correction_path = dump_dir / Path(data_name) / correction_fname + dump_data = ExtendedPath(dump_path).read_json() + + dump_map = {d["utterance_id"]: d for d in dump_data["data"]} + correction_data = ExtendedPath(correction_path).read_json() + + def correction_dp(c): + dp = copy.deepcopy(dump_map[c["code"]]) + dp["valid"] = c["value"]["status"] == "Correct" + return dp + + corrected_dump = [correction_dp(c) for c in correction_data] + df = pd.DataFrame(corrected_dump) + print(f"Total samples: {len(df)}") + for (c, g) in df.groupby("caller"): + total = len(g) + valid = len(g[g["valid"] == True]) + valid_rate = valid * 100 / total + print(f"Caller: {c} Valid%:{valid_rate:.2f} of {total} samples") + + @app.command() def fill_unannotated( data_name: str = typer.Option("call_alphanum", show_default=True), @@ -329,7 +401,9 @@ def clear_mongo_corrections(): if delete: col = get_mongo_conn(col="asr_validation") col.delete_many({"type": "correction"}) + col.delete_many({"type": "current_cursor"}) typer.echo("deleted mongo collection.") + return typer.echo("Aborted") diff --git a/jasper/data/validation/ui.py b/jasper/data/validation/ui.py index b10b88f..3915aeb 100644 --- a/jasper/data/validation/ui.py +++ b/jasper/data/validation/ui.py @@ -2,6 +2,7 @@ from pathlib import Path import streamlit as st import typer +from uuid import uuid4 from ..utils import ExtendedPath, get_mongo_conn from .st_rerun import rerun @@ -11,25 +12,25 @@ app = typer.Typer() if not hasattr(st, "mongo_connected"): st.mongoclient = get_mongo_conn(col="asr_validation") mongo_conn = st.mongoclient + st.task_id = str(uuid4()) def current_cursor_fn(): # mongo_conn = st.mongoclient - cursor_obj = mongo_conn.find_one({"type": "current_cursor"}) + cursor_obj = mongo_conn.find_one( + {"type": "current_cursor", "task_id": st.task_id} + ) cursor_val = cursor_obj["cursor"] return cursor_val def update_cursor_fn(val=0): mongo_conn.find_one_and_update( - {"type": "current_cursor"}, - {"$set": {"type": "current_cursor", "cursor": val}}, + {"type": "current_cursor", "task_id": st.task_id}, + {"$set": {"type": "current_cursor", "task_id": st.task_id, "cursor": val}}, upsert=True, ) rerun() def get_correction_entry_fn(code): - # mongo_conn = st.mongoclient - # cursor_obj = mongo_conn.find_one({"type": "correction", "code": code}) - # cursor_val = cursor_obj["cursor"] return mongo_conn.find_one( {"type": "correction", "code": code}, projection={"_id": False} ) @@ -37,18 +38,25 @@ if not hasattr(st, "mongo_connected"): def update_entry_fn(code, value): mongo_conn.find_one_and_update( {"type": "correction", "code": code}, - {"$set": {"value": value}}, + {"$set": {"value": value, "task_id": st.task_id}}, upsert=True, ) - cursor_obj = mongo_conn.find_one({"type": "current_cursor"}) - if not cursor_obj: - update_cursor_fn(0) + def set_task_fn(mf_path): + task_path = mf_path.parent / Path(f"task-{st.task_id}.lck") + if not task_path.exists(): + print(f"creating task lock at {task_path}") + task_path.touch() + st.get_current_cursor = current_cursor_fn st.update_cursor = update_cursor_fn st.get_correction_entry = get_correction_entry_fn st.update_entry = update_entry_fn + st.set_task = set_task_fn st.mongo_connected = True + cursor_obj = mongo_conn.find_one({"type": "current_cursor", "task_id": st.task_id}) + if not cursor_obj: + update_cursor_fn(0) @st.cache() @@ -59,6 +67,7 @@ def load_ui_data(validation_ui_data_path: Path): @app.command() def main(manifest: Path): + st.set_task(manifest) ui_config = load_ui_data(manifest) asr_data = ui_config["data"] use_domain_asr = ui_config.get("use_domain_asr", True) @@ -70,10 +79,11 @@ def main(manifest: Path): st.update_cursor(0) sample = asr_data[sample_no] title_type = "Speller " if use_domain_asr else "" + task_uid = st.task_id.rsplit("-", 1)[1] if annotation_only: - st.title(f"ASR Annotation") + st.title(f"ASR Annotation - # {task_uid}") else: - st.title(f"ASR {title_type}Validation") + st.title(f"ASR {title_type}Validation - # {task_uid}") addl_text = f"spelled *{sample['spoken']}*" if use_domain_asr else "" st.markdown(f"{sample_no+1} of {len(asr_data)} : **{sample['text']}**" + addl_text) new_sample = st.number_input( @@ -88,6 +98,8 @@ def main(manifest: Path): st.sidebar.markdown(f"Expected Spelled: *{sample['spoken']}*") st.sidebar.title("Results:") st.sidebar.markdown(f"Pretrained: **{sample['pretrained_asr']}**") + if "caller" in sample: + st.sidebar.markdown(f"Caller: **{sample['caller']}**") if use_domain_asr: st.sidebar.markdown(f"Domain: **{sample['domain_asr']}**") st.sidebar.title(f"Speller WER: {sample['domain_wer']:.2f}%")