From 6d149d282d47f7c8bf3efc7d46f0648db1966dfe Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Tue, 9 Jun 2020 19:16:24 +0530 Subject: [PATCH] 1. added a data extraction type argument 2. cleanup/refactor --- jasper/data/call_recycler.py | 173 +++++++++++++++++++++-------------- 1 file changed, 102 insertions(+), 71 deletions(-) diff --git a/jasper/data/call_recycler.py b/jasper/data/call_recycler.py index 93cb023..09911b1 100644 --- a/jasper/data/call_recycler.py +++ b/jasper/data/call_recycler.py @@ -1,10 +1,16 @@ import typer from pathlib import Path +from enum import Enum + app = typer.Typer() + @app.command() -def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")): +def export_all_logs( + call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True), + domain: str = typer.Option("sia-data.agaralabs.com", show_default=True), +): from .utils import get_mongo_conn from collections import defaultdict from ruamel.yaml import YAML @@ -14,14 +20,14 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")): caller_calls = defaultdict(lambda: []) for call in mongo_coll.find(): sysid = call["SystemID"] - call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}" + call_uri = f"http://{domain}/calls/{sysid}" caller = call["Caller"] caller_calls[caller].append(call_uri) caller_list = [] for caller in caller_calls: caller_list.append({"name": caller, "calls": caller_calls[caller]}) output_yaml = {"users": caller_list} - typer.echo("exporting call logs to yaml file") + typer.echo(f"exporting call logs to yaml file at {call_logs_file}") with call_logs_file.open("w") as yf: yaml.dump(output_yaml, yf) @@ -30,7 +36,8 @@ def export_all_logs(call_logs_file: Path = Path("./call_sia_logs.yaml")): def export_calls_between( start_cid: str, end_cid: str, - call_logs_file: Path = Path("./call_sia_logs.yaml"), + call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True), + domain: str = typer.Option("sia-data.agaralabs.com", show_default=True), mongo_port: int = 27017, ): from collections import defaultdict @@ -51,28 +58,38 @@ def export_calls_between( ) for call in call_query: sysid = call["SystemID"] - call_uri = f"http://sia-data.agaralabs.com/calls/{sysid}" + call_uri = f"http://{domain}/calls/{sysid}" caller = call["Caller"] caller_calls[caller].append(call_uri) caller_list = [] for caller in caller_calls: caller_list.append({"name": caller, "calls": caller_calls[caller]}) output_yaml = {"users": caller_list} - typer.echo("exporting call logs to yaml file") + typer.echo(f"exporting call logs to yaml file at {call_logs_file}") with call_logs_file.open("w") as yf: yaml.dump(output_yaml, yf) +class ExtractionType(str, Enum): + flow = "flow" + data = "data" + + @app.command() def analyze( leaderboard: bool = False, plot_calls: bool = False, extract_data: bool = False, + extraction_type: ExtractionType = typer.Option( + ExtractionType.data, show_default=True + ), download_only: bool = False, call_logs_file: Path = typer.Option(Path("./call_logs.yaml"), show_default=True), output_dir: Path = Path("./data"), data_name: str = None, - mongo_uri: str = typer.Option("mongodb://localhost:27017/test.calls", show_default=True), + mongo_uri: str = typer.Option( + "mongodb://localhost:27017/test.calls", show_default=True + ), ): from urllib.parse import urlsplit @@ -145,23 +162,77 @@ def analyze( def chunk_n(evs, n): return [evs[i * n : (i + 1) * n] for i in range((len(evs) + n - 1) // n)] - def get_data_points(utter_events, td_fn): - data_points = [] - for evs in chunk_n(utter_events, 3): - try: - assert evs[0]["Type"] == "CONV_RESULT" - assert evs[1]["Type"] == "STARTED_SPEAKING" - assert evs[2]["Type"] == "STOPPED_SPEAKING" - start_time = td_fn(evs[1]).total_seconds() - 1.5 - end_time = td_fn(evs[2]).total_seconds() - code = evs[0]["Msg"] - data_points.append( - {"start_time": start_time, "end_time": end_time, "code": code} - ) - except AssertionError: - # skipping invalid data_points - pass - return data_points + if extraction_type == ExtractionType.data: + + def is_utter_event(ev): + return ( + (ev["Author"] == "CONV" or ev["Author"] == "ASR") + and (ev["Type"] != "DEBUG") + and ev["Type"] != "ASR_RESULT" + ) + + def get_data_points(utter_events, td_fn): + data_points = [] + for evs in chunk_n(utter_events, 3): + try: + assert evs[0]["Type"] == "CONV_RESULT" + assert evs[1]["Type"] == "STARTED_SPEAKING" + assert evs[2]["Type"] == "STOPPED_SPEAKING" + start_time = td_fn(evs[1]).total_seconds() - 1.5 + end_time = td_fn(evs[2]).total_seconds() + spoken = evs[0]["Msg"] + data_points.append( + {"start_time": start_time, "end_time": end_time, "code": spoken} + ) + except AssertionError: + # skipping invalid data_points + pass + return data_points + + def text_extractor(spoken): + return re.search(r"'(.*)'", spoken).groups(0)[0] if len(spoken) > 6 else spoken + + elif extraction_type == ExtractionType.flow: + + def is_final_asr_event_or_spoken(ev): + pld = json.loads(ev["Payload"]) + return ( + pld["AsrResult"]["Results"][0]["IsFinal"] + if ev["Type"] == "ASR_RESULT" + else True + ) + + def is_utter_event(ev): + return ( + ev["Author"] == "CONV" + or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev)) + ) and (ev["Type"] != "DEBUG") + + def get_data_points(utter_events, td_fn): + data_points = [] + for evs in chunk_n(utter_events, 4): + try: + assert len(evs) == 4 + assert evs[0]["Type"] == "CONV_RESULT" + assert evs[1]["Type"] == "STARTED_SPEAKING" + assert evs[2]["Type"] == "ASR_RESULT" + assert evs[3]["Type"] == "STOPPED_SPEAKING" + start_time = td_fn(evs[1]).total_seconds() - 1.5 + end_time = td_fn(evs[2]).total_seconds() + conv_msg = evs[0]["Msg"] + if 'full name' in conv_msg.lower(): + pld = json.loads(evs[2]["Payload"]) + spoken = pld["AsrResult"]["Results"][0]["Alternatives"][0]['Transcript'] + data_points.append( + {"start_time": start_time, "end_time": end_time, "code": spoken} + ) + except AssertionError: + # skipping invalid data_points + pass + return data_points + + def text_extractor(spoken): + return spoken def process_call(call_obj): call_meta = get_call_meta(call_obj) @@ -184,13 +255,6 @@ def analyze( get_ev_fev_timedelta = gen_ev_fev_timedelta(first_audio_ev) - def is_utter_event(ev): - return ( - (ev["Author"] == "CONV" or ev["Author"] == "ASR") - and (ev["Type"] != "DEBUG") - and ev["Type"] != "ASR_RESULT" - ) - uevs = list(filter(is_utter_event, call_events)) ev_count = len(uevs) utter_events = uevs[: ev_count - ev_count % 3] @@ -201,36 +265,6 @@ def analyze( s3_wav_url_p.netloc, s3_wav_url_p.path[1:], str(saved_wav_path) ) - # %config InlineBackend.figure_format = "retina" - - def plot_events(y, sr, utter_events, file_path): - plt.figure(figsize=(16, 12)) - librosa.display.waveplot(y=y, sr=sr) - # plt.tight_layout() - for evs in chunk_n(utter_events, 3): - assert evs[0]["Type"] == "CONV_RESULT" - assert evs[1]["Type"] == "STARTED_SPEAKING" - assert evs[2]["Type"] == "STOPPED_SPEAKING" - for ev in evs: - # print(ev["Type"]) - ev_type = ev["Type"] - pos = get_ev_fev_timedelta(ev).total_seconds() - if ev_type == "STARTED_SPEAKING": - pos = pos - 1.5 - plt.axvline(pos) # , label="pyplot vertical line") - plt.text( - pos, - 0.2, - f"event:{ev_type}:{ev['Msg']}", - rotation=90, - horizontalalignment="left" - if ev_type != "STOPPED_SPEAKING" - else "right", - verticalalignment="center", - ) - plt.title("Monophonic") - plt.savefig(file_path, format="png") - return { "wav_path": saved_wav_path, "num_samples": len(utter_events) // 3, @@ -263,7 +297,6 @@ def analyze( call_lens = lens["users"].Each()["calls"].Each() call_lens.modify(ensure_call)(call_logs) - # @plot_app.command() def plot_calls_data(): def plot_data_points(y, sr, data_points, file_path): plt.figure(figsize=(16, 12)) @@ -317,16 +350,14 @@ def analyze( .set_frame_rate(24000) ) for dp_id, dp in enumerate(data_points): - start, end, code = dp["start_time"], dp["end_time"], dp["code"] - code_seg = call_seg[start * 1000 : end * 1000] - code_fb = BytesIO() - code_seg.export(code_fb, format="wav") - code_wav = code_fb.getvalue() + start, end, spoken = dp["start_time"], dp["end_time"], dp["code"] + spoken_seg = call_seg[start * 1000 : end * 1000] + spoken_fb = BytesIO() + spoken_seg.export(spoken_fb, format="wav") + spoken_wav = spoken_fb.getvalue() # search for actual pnr code and handle plain codes as well - extracted_code = ( - re.search(r"'(.*)'", code).groups(0)[0] if len(code) > 6 else code - ) - yield extracted_code, code_seg.duration_seconds, code_wav + extracted_code = text_extractor(spoken) + yield extracted_code, spoken_seg.duration_seconds, spoken_wav call_lens = lens["users"].Each()["calls"].Each() call_stats = call_lens.modify(retrieve_processed_callmeta)(call_logs)