From 069392d09893106bc8b66a3a436af60224fbda4a Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Mon, 29 Jun 2020 14:24:56 +0530
Subject: [PATCH] 1. added a test generator and an SLU evaluator 2. the UI dump now includes GCP results 3. show default options for more args of the validation process commands

---
 jasper/data/call_recycler.py | 16 +--
 jasper/data/slu_evaluator.py | 180 ++++++++++++++++++++++++++
 jasper/data/test_generator.py | 99 ++++++++++++++
 jasper/data/utils.py | 206 ++++++++++++++++++++++++++++++
 jasper/data/validation/process.py | 14 +-
 setup.py | 3 +
 6 files changed, 500 insertions(+), 18 deletions(-)
 create mode 100644 jasper/data/slu_evaluator.py
 create mode 100644 jasper/data/test_generator.py

diff --git a/jasper/data/call_recycler.py b/jasper/data/call_recycler.py
index 881159b..fbab009 100644
--- a/jasper/data/call_recycler.py
+++ b/jasper/data/call_recycler.py
@@ -147,7 +147,7 @@ def analyze(
     import matplotlib.pyplot as plt
     import matplotlib
     from tqdm import tqdm
-    from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll
+    from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
     from pydub import AudioSegment
     from natural.date import compress
 
@@ -170,18 +170,6 @@ def analyze(
 
     call_logs = yaml.load(call_logs_file.read_text())
 
-    def get_call_meta(call_obj):
-        meta_s3_uri = call_obj["DataURI"]
-        s3_event_url_p = urlsplit(meta_s3_uri)
-        saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
-        if not saved_meta_path.exists():
-            print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
-            s3.download_file(
-                s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
-            )
-        call_metas = json.load(saved_meta_path.open())
-        return call_metas
-
     def gen_ev_fev_timedelta(fev):
         fev_p = Timestamp()
         fev_p.FromJsonString(fev["CreatedTS"])
@@ -283,7 +271,7 @@ def analyze(
         return spoken
 
     def process_call(call_obj):
-        call_meta = get_call_meta(call_obj)
+        call_meta = get_call_logs(call_obj, s3, call_meta_dir)
         call_events = call_meta["Events"]
 
         def is_writer_uri_event(ev):
diff --git a/jasper/data/slu_evaluator.py b/jasper/data/slu_evaluator.py
new file mode 100644
index 0000000..b29efdd
--- /dev/null
+++ b/jasper/data/slu_evaluator.py
@@ -0,0 +1,180 @@
+import typer
+from pathlib import Path
+import json
+
+# from .utils import generate_dates, asr_test_writer
+
+app = typer.Typer()
+
+
+def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
+    from time import sleep
+    import subprocess
+    from .utils import ExtendedPath, get_call_logs
+
+    coll.delete_many({"CallID": test_path.name})
+    # test_path = dump_dir / data_name / test_file
+    # "../saas_reg/regression/run.sh -f data/asr_data/call_upwork_test_cnd_cities/asr_test.reg"
+    test_output = subprocess.run(
+        ["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
+    )
+    if test_output.returncode != 0:
+        print(f"Error running test {test_path}")
+        return
+
+    def get_meta():
+        call_meta = coll.find_one({"CallID": test_path.name})
+        if call_meta:
+            return call_meta
+        else:
+            sleep(2)
+            return get_meta()
+
+    call_meta = get_meta()
+    call_logs = get_call_logs(call_meta, s3, call_meta_dir)
+    call_events = call_logs["Events"]
+
+    test_data_path = test_path.with_suffix(".result.json")
+    test_data = ExtendedPath(test_data_path).read_json()
+
+    def is_final_asr_event_or_spoken(ev):
+        pld = json.loads(ev["Payload"])
+        return (
+            pld["AsrResult"]["Results"][0]["IsFinal"]
+            if ev["Type"] == "ASR_RESULT"
+            else False
+        )
+
+    def is_test_event(ev):
+        return (
+ ev["Author"] == "NLU" + or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev)) + ) and (ev["Type"] != "DEBUG") + + test_evs = list(filter(is_test_event, call_events)) + if len(test_evs) == 2: + try: + asr_payload = test_evs[0]["Payload"] + asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0] + alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]] + gcp_result = "|".join(alt_tscripts) + entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"] + nlu_payload = test_evs[1]["Payload"] + nlu_result_payload = json.loads(nlu_payload)["NluResults"] + entity = test_data[0]["entity"] + text = test_data[0]["text"] + audio_filepath = test_data[0]["audio_filepath"] + pretrained_asr = test_data[0]["pretrained_asr"] + nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0] + asr_entity = city_code[entity] if entity in city_code else "UNKNOWN" + entities_match = asr_entity == nlu_entity + result = "Success" if entities_match else "Fail" + return { + "expected_entity": entity, + "text": text, + "audio_filepath": audio_filepath, + "pretrained_asr": pretrained_asr, + "entity_asr": entity_asr, + "google_asr": gcp_result, + "nlu_result": nlu_result_payload, + "asr_entity": asr_entity, + "nlu_entity": nlu_entity, + "result": result, + } + except Exception: + return { + "expected_entity": test_data[0]["entity"], + "text": test_data[0]["text"], + "audio_filepath": test_data[0]["audio_filepath"], + "pretrained_asr": test_data[0]["pretrained_asr"], + "entity_asr": "", + "google_asr": "", + "nlu_result": "", + "asr_entity": "", + "nlu_entity": "", + "result": "Error", + } + else: + return { + "expected_entity": test_data[0]["entity"], + "text": test_data[0]["text"], + "audio_filepath": test_data[0]["audio_filepath"], + "pretrained_asr": test_data[0]["pretrained_asr"], + "entity_asr": "", + "google_asr": "", + "nlu_result": "", + "asr_entity": "", + "nlu_entity": "", + "result": "Empty", + } + + +@app.command() +def evaluate_slu( + # conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True), + data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True), + # extraction_key: str = "Cities", + dump_dir: Path = Path("./data/asr_data"), + call_meta_dir: Path = Path("./data/call_metas"), + test_file_pref: str = "asr_test", + mongo_uri: str = typer.Option( + "mongodb://localhost:27017/test.calls", show_default=True + ), + test_results: Path = Path("./data/results.csv"), + airport_codes: Path = Path("./airports_code.csv"), + reg_path: Path = Path("../saas_reg/regression/run.sh"), + test_id: str = "5ef481f27031edf6910e94e0", +): + # import json + from .utils import get_mongo_coll + import pandas as pd + import boto3 + from concurrent.futures import ThreadPoolExecutor + from functools import partial + + # import subprocess + # from time import sleep + import csv + from tqdm import tqdm + + s3 = boto3.client("s3") + df = pd.read_csv(airport_codes)[["iata", "city"]] + city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict() + + test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg")) + coll = get_mongo_coll(mongo_uri) + with test_results.open("w") as csvfile: + fieldnames = [ + "expected_entity", + "text", + "audio_filepath", + "pretrained_asr", + "entity_asr", + "google_asr", + "nlu_result", + "asr_entity", + "nlu_entity", + "result", + ] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + with ThreadPoolExecutor(max_workers=8) as exe: + print("starting all 
loading tasks") + for test_result in tqdm( + exe.map( + partial(run_test, reg_path, coll, s3, call_meta_dir, city_code), + test_files, + ), + position=0, + leave=True, + total=len(test_files), + ): + writer.writerow(test_result) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data/test_generator.py b/jasper/data/test_generator.py new file mode 100644 index 0000000..6e14d5c --- /dev/null +++ b/jasper/data/test_generator.py @@ -0,0 +1,99 @@ +import typer +from pathlib import Path +from .utils import generate_dates, asr_test_writer + +app = typer.Typer() + + +@app.command() +def export_test_reg( + conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True), + data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True), + extraction_key: str = "Cities", + dump_dir: Path = Path("./data/asr_data"), + dump_file: Path = Path("ui_dump.json"), + manifest_file: Path = Path("manifest.json"), + test_file: Path = Path("asr_test.reg"), +): + from .utils import ( + ExtendedPath, + asr_manifest_reader, + gcp_transcribe_gen, + parallel_apply, + ) + from ..client import transcribe_gen + from pydub import AudioSegment + from queue import PriorityQueue + + jasper_map = { + "PNRs": 8045, + "Cities": 8046, + "Names": 8047, + "Dates": 8048, + } + # jasper_map = {"PNRs": 8050, "Cities": 8050, "Names": 8050, "Dates": 8050} + transcriber_gcp = gcp_transcribe_gen() + transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key]) + transcriber_all_trained = transcribe_gen(asr_port=8050) + transcriber_libri_all_trained = transcribe_gen(asr_port=8051) + + def find_ent(dd, conv_data): + ents = PriorityQueue() + for ent in conv_data: + if ent in dd["text"]: + ents.put((-len(ent), ent)) + return ents.get_nowait()[1] + + def process_data(d): + orig_seg = AudioSegment.from_wav(d["audio_path"]) + jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000) + gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000) + deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path( + d["audio_path"].stem + ".txt" + ) + if deepgram_file.exists(): + d["deepgram"] = "".join( + [s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")] + ) + else: + d["deepgram"] = 'Not Found' + d["audio_path"] = str(d["audio_path"]) + d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data) + d["jasper_trained"] = transcriber_trained(jas_seg.raw_data) + d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data) + d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data) + return d + + conv_data = ExtendedPath(conv_src).read_json() + conv_data["Dates"] = generate_dates() + + dump_data_path = dump_dir / Path(data_name) / dump_file + ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"] + ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data} + manifest_path = dump_dir / Path(data_name) / manifest_file + test_points = list(asr_manifest_reader(manifest_path)) + test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points] + test_data = parallel_apply(process_data, test_data_objs) + # test_data = [process_data(t) for t in test_data_objs] + test_path = dump_dir / Path(data_name) / test_file + + def dd_gen(dump_data): + for dd in dump_data: + ent = find_ent(dd, conv_data[extraction_key]) + dd["entity"] = ent + if ent: + yield dd + + asr_test_writer(test_path, dd_gen(test_data)) + # for i, b in enumerate(batch(test_data, 1)): + # test_fname = 
Path(f"{test_file.stem}_{i}.reg") + # test_path = dump_dir / Path(data_name) / test_fname + # asr_test_writer(test_path, dd_gen(test_data)) + + +def main(): + app() + + +if __name__ == "__main__": + main() diff --git a/jasper/data/utils.py b/jasper/data/utils.py index 1e90f00..1f81a98 100644 --- a/jasper/data/utils.py +++ b/jasper/data/utils.py @@ -7,6 +7,7 @@ from itertools import product from functools import partial from math import floor from uuid import uuid4 +from urllib.parse import urlsplit from concurrent.futures import ThreadPoolExecutor import numpy as np @@ -99,6 +100,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F "data": [], } data_funcs = [] + transcriber_gcp = gcp_transcribe_gen() transcriber_pretrained = transcribe_gen(asr_port=8044) with asr_manifest.open("w") as mf: print(f"writing manifest to {asr_manifest}") @@ -115,6 +117,8 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F rel_pnr_path, ): pretrained_result = transcriber_pretrained(aud_seg.raw_data) + gcp_seg = aud_seg.set_frame_rate(16000) + gcp_result = transcriber_gcp(gcp_seg.raw_data) pretrained_wer = word_error_rate([transcript], [pretrained_result]) wav_plot_path = ( dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png") @@ -130,6 +134,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F "spoken": transcript, "caller": caller_name, "utterance_id": fname, + "gcp_asr": gcp_result, "pretrained_asr": pretrained_result, "pretrained_wer": pretrained_wer, "plot_path": str(wav_plot_path), @@ -194,6 +199,32 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source): mf.write(manifest) +def asr_test_writer(out_file_path: Path, source): + def dd_str(dd, idx): + path = dd["audio_filepath"] + # dur = dd["duration"] + # return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n" + return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n" + + res_file = out_file_path.with_suffix(".result.json") + with out_file_path.open("w") as of: + print(f"opening {out_file_path} for writing test") + results = [] + idx = 0 + for ui_dd in source: + results.append(ui_dd) + out_str = dd_str(ui_dd, idx) + of.write(out_str) + idx += 1 + of.write("DO_HANGUP\n") + ExtendedPath(res_file).write_json(results) + + +def batch(iterable, n=1): + ls = len(iterable) + return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)] + + class ExtendedPath(type(Path())): """docstring for ExtendedPath.""" @@ -278,6 +309,181 @@ def generate_dates(): return [dm for d, m in product(days, months) for dm in canon_vars(d, m)] +def get_call_logs(call_obj, s3, call_meta_dir): + meta_s3_uri = call_obj["DataURI"] + s3_event_url_p = urlsplit(meta_s3_uri) + saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name) + if not saved_meta_path.exists(): + print(f"downloading : {saved_meta_path} from {meta_s3_uri}") + s3.download_file( + s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path) + ) + call_metas = json.load(saved_meta_path.open()) + return call_metas + + +def gcp_transcribe_gen(): + from google.cloud import speech_v1 + from google.cloud.speech_v1 import enums + + # import io + client = speech_v1.SpeechClient() + # local_file_path = 'resources/brooklyn_bridge.raw' + + # The language of the supplied audio + language_code = "en-US" + model = "phone_call" + + # Sample rate in Hertz of the audio data sent + sample_rate_hertz = 16000 + + # Encoding of audio data sent. This sample sets this explicitly. 
+ # This field is optional for FLAC and WAV audio formats. + encoding = enums.RecognitionConfig.AudioEncoding.LINEAR16 + config = { + "language_code": language_code, + "sample_rate_hertz": sample_rate_hertz, + "encoding": encoding, + "model": model, + "enable_automatic_punctuation": True, + "max_alternatives": 10, + "enable_word_time_offsets": True, # used to detect start and end time of utterances + "speech_contexts": [ + { + "phrases": [ + "$OOV_CLASS_ALPHANUMERIC_SEQUENCE", + "$OOV_CLASS_DIGIT_SEQUENCE", + "$TIME", + "$YEAR", + ] + }, + { + "phrases": [ + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + "Q", + "R", + "S", + "T", + "U", + "V", + "W", + "X", + "Y", + "Z", + ] + }, + { + "phrases": [ + "PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE", + "my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE", + "my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE", + "PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE", + "It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE", + "$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR", + ] + }, + {"phrases": ["my name is"]}, + {"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]}, + { + "phrases": [ + "John Smith", + "Carina Hu", + "Travis Lim", + "Marvin Tan", + "Samuel Tan", + "Dawn Mathew", + "Dawn", + "Mathew", + ] + }, + { + "phrases": [ + "Beijing", + "Tokyo", + "London", + "19 August", + "7 October", + "11 December", + "17 September", + "19th August", + "7th October", + "11th December", + "17th September", + "ABC123", + "KWXUNP", + "XLU5K1", + "WL2JV6", + "KBS651", + ] + }, + { + "phrases": [ + "first flight", + "second flight", + "third flight", + "first option", + "second option", + "third option", + "first one", + "second one", + "third one", + ] + }, + ], + "metadata": { + "industry_naics_code_of_audio": 481111, + "interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL, + }, + } + + def sample_recognize(content): + """ + Transcribe a short audio file using synchronous speech recognition + + Args: + local_file_path Path to local audio file, e.g. 
/path/audio.wav + """ + + # with io.open(local_file_path, "rb") as f: + # content = f.read() + audio = {"content": content} + + response = client.recognize(config, audio) + for result in response.results: + # First alternative is the most probable result + return "/".join([alt.transcript for alt in result.alternatives]) + # print(u"Transcript: {}".format(alternative.transcript)) + return "" + + return sample_recognize + + +def parallel_apply(fn, iterable, workers=8): + with ThreadPoolExecutor(max_workers=workers) as exe: + print(f"parallelly applying {fn}") + return [ + res + for res in tqdm( + exe.map(fn, iterable), position=0, leave=True, total=len(iterable) + ) + ] + + def main(): for c in random_pnr_generator(): print(c) diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py index aa21ba4..32b81ea 100644 --- a/jasper/data/validation/process.py +++ b/jasper/data/validation/process.py @@ -271,7 +271,7 @@ def split_extract( dump_file: Path = Path("ui_dump.json"), manifest_file: Path = Path("manifest.json"), corrections_file: str = typer.Option("corrections.json", show_default=True), - conv_data_path: Path = Path("./data/conv_data.json"), + conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True), extraction_type: ExtractionType = ExtractionType.all, ): import shutil @@ -299,10 +299,16 @@ def split_extract( asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen)) ui_data_path = dump_dir / Path(data_name) / dump_file - ui_data = json.load(ui_data_path.open())["data"] + orig_ui_data = ExtendedPath(ui_data_path).read_json() + ui_data = orig_ui_data["data"] file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data} extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data)) - ExtendedPath(dest_ui_path).write_json(extracted_ui_data) + final_data = [] + for i, d in enumerate(extracted_ui_data): + d['real_idx'] = i + final_data.append(d) + orig_ui_data['data'] = final_data + ExtendedPath(dest_ui_path).write_json(orig_ui_data) if corrections_file: dest_correction_path = dest_data_dir / corrections_file @@ -331,7 +337,7 @@ def update_corrections( manifest_file: Path = Path("manifest.json"), corrections_file: Path = Path("corrections.json"), ui_dump_file: Path = Path("ui_dump.json"), - skip_incorrect: bool = True, + skip_incorrect: bool = typer.Option(True, show_default=True), ): data_manifest_path = dump_dir / Path(data_name) / manifest_file corrections_path = dump_dir / Path(data_name) / corrections_file diff --git a/setup.py b/setup.py index d1c85d4..fddfb80 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ extra_requirements = { "streamlit==0.58.0", "natural==0.2.0", "stringcase==1.2.0", + "google-cloud-speech~=1.3.1", ] # "train": [ # "torchaudio==0.5.0", @@ -66,12 +67,14 @@ setup( "jasper_data_tts_generate = jasper.data.tts_generator:main", "jasper_data_conv_generate = jasper.data.conv_generator:main", "jasper_data_nlu_generate = jasper.data.nlu_generator:main", + "jasper_data_test_generate = jasper.data.test_generator:main", "jasper_data_call_recycle = jasper.data.call_recycler:main", "jasper_data_asr_recycle = jasper.data.asr_recycler:main", "jasper_data_rev_recycle = jasper.data.rev_recycler:main", "jasper_data_server = jasper.data.server:main", "jasper_data_validation = jasper.data.validation.process:main", "jasper_data_preprocess = jasper.data.process:main", + "jasper_data_slu_evaluate = jasper.data.slu_evaluator:main", ] }, zip_safe=False,
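
Usage note (illustrative, not part of the patch): a minimal sketch of the new helpers added to jasper/data/utils.py. The Mongo URI, CallID, and local paths below are assumptions for the example, not values fixed by this change.

    import boto3
    from pathlib import Path

    from jasper.data.utils import get_call_logs, get_mongo_coll, gcp_transcribe_gen

    # Look up a call document the same way slu_evaluator.run_test does,
    # then fetch its event log from S3 via the new get_call_logs helper.
    s3 = boto3.client("s3")
    coll = get_mongo_coll("mongodb://localhost:27017/test.calls")  # assumed local instance
    call_obj = coll.find_one({"CallID": "asr_test.reg"})  # hypothetical CallID
    events = get_call_logs(call_obj, s3, Path("./data/call_metas"))["Events"]
    print(f"{len(events)} call events downloaded")

    # gcp_transcribe_gen returns a callable that takes raw 16 kHz LINEAR16 audio bytes,
    # matching how test_generator feeds it gcp_seg.raw_data.
    transcribe = gcp_transcribe_gen()
    print(transcribe(Path("sample_16k.raw").read_bytes()))  # hypothetical audio file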