1. added a test generator and slu evaluator

2. ui dump now include gcp results
3. showing default option for more args validation process commands
Malar Kannan 2020-06-29 14:24:56 +05:30
parent 515e9c1037
commit 069392d098
6 changed files with 500 additions and 18 deletions

View File

@ -147,7 +147,7 @@ def analyze(
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib import matplotlib
from tqdm import tqdm from tqdm import tqdm
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
from pydub import AudioSegment from pydub import AudioSegment
from natural.date import compress from natural.date import compress
@ -170,18 +170,6 @@ def analyze(
call_logs = yaml.load(call_logs_file.read_text()) call_logs = yaml.load(call_logs_file.read_text())
def get_call_meta(call_obj):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gen_ev_fev_timedelta(fev): def gen_ev_fev_timedelta(fev):
fev_p = Timestamp() fev_p = Timestamp()
fev_p.FromJsonString(fev["CreatedTS"]) fev_p.FromJsonString(fev["CreatedTS"])
@ -283,7 +271,7 @@ def analyze(
return spoken return spoken
def process_call(call_obj): def process_call(call_obj):
call_meta = get_call_meta(call_obj) call_meta = get_call_logs(call_obj, s3, call_meta_dir)
call_events = call_meta["Events"] call_events = call_meta["Events"]
def is_writer_uri_event(ev): def is_writer_uri_event(ev):

View File

@ -0,0 +1,180 @@
import typer
from pathlib import Path
import json
# from .utils import generate_dates, asr_test_writer
app = typer.Typer()
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
from time import sleep
import subprocess
from .utils import ExtendedPath, get_call_logs
coll.delete_many({"CallID": test_path.name})
# test_path = dump_dir / data_name / test_file
# "../saas_reg/regression/run.sh -f data/asr_data/call_upwork_test_cnd_cities/asr_test.reg"
test_output = subprocess.run(
["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
)
if test_output.returncode != 0:
print("Error running test {test_file}")
return
def get_meta():
call_meta = coll.find_one({"CallID": test_path.name})
if call_meta:
return call_meta
else:
sleep(2)
return get_meta()
call_meta = get_meta()
call_logs = get_call_logs(call_meta, s3, call_meta_dir)
call_events = call_logs["Events"]
test_data_path = test_path.with_suffix(".result.json")
test_data = ExtendedPath(test_data_path).read_json()
def is_final_asr_event_or_spoken(ev):
pld = json.loads(ev["Payload"])
return (
pld["AsrResult"]["Results"][0]["IsFinal"]
if ev["Type"] == "ASR_RESULT"
else False
)
def is_test_event(ev):
return (
ev["Author"] == "NLU"
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
) and (ev["Type"] != "DEBUG")
test_evs = list(filter(is_test_event, call_events))
if len(test_evs) == 2:
try:
asr_payload = test_evs[0]["Payload"]
asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0]
alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]]
gcp_result = "|".join(alt_tscripts)
entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
nlu_payload = test_evs[1]["Payload"]
nlu_result_payload = json.loads(nlu_payload)["NluResults"]
entity = test_data[0]["entity"]
text = test_data[0]["text"]
audio_filepath = test_data[0]["audio_filepath"]
pretrained_asr = test_data[0]["pretrained_asr"]
nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
asr_entity = city_code[entity] if entity in city_code else "UNKNOWN"
entities_match = asr_entity == nlu_entity
result = "Success" if entities_match else "Fail"
return {
"expected_entity": entity,
"text": text,
"audio_filepath": audio_filepath,
"pretrained_asr": pretrained_asr,
"entity_asr": entity_asr,
"google_asr": gcp_result,
"nlu_result": nlu_result_payload,
"asr_entity": asr_entity,
"nlu_entity": nlu_entity,
"result": result,
}
except Exception:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Error",
}
else:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Empty",
}
@app.command()
def evaluate_slu(
# conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
# extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
call_meta_dir: Path = Path("./data/call_metas"),
test_file_pref: str = "asr_test",
mongo_uri: str = typer.Option(
"mongodb://localhost:27017/test.calls", show_default=True
),
test_results: Path = Path("./data/results.csv"),
airport_codes: Path = Path("./airports_code.csv"),
reg_path: Path = Path("../saas_reg/regression/run.sh"),
test_id: str = "5ef481f27031edf6910e94e0",
):
# import json
from .utils import get_mongo_coll
import pandas as pd
import boto3
from concurrent.futures import ThreadPoolExecutor
from functools import partial
# import subprocess
# from time import sleep
import csv
from tqdm import tqdm
s3 = boto3.client("s3")
df = pd.read_csv(airport_codes)[["iata", "city"]]
city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()
test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
coll = get_mongo_coll(mongo_uri)
with test_results.open("w") as csvfile:
fieldnames = [
"expected_entity",
"text",
"audio_filepath",
"pretrained_asr",
"entity_asr",
"google_asr",
"nlu_result",
"asr_entity",
"nlu_entity",
"result",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
with ThreadPoolExecutor(max_workers=8) as exe:
print("starting all loading tasks")
for test_result in tqdm(
exe.map(
partial(run_test, reg_path, coll, s3, call_meta_dir, city_code),
test_files,
),
position=0,
leave=True,
total=len(test_files),
):
writer.writerow(test_result)
def main():
app()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,99 @@
import typer
from pathlib import Path
from .utils import generate_dates, asr_test_writer
app = typer.Typer()
@app.command()
def export_test_reg(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
test_file: Path = Path("asr_test.reg"),
):
from .utils import (
ExtendedPath,
asr_manifest_reader,
gcp_transcribe_gen,
parallel_apply,
)
from ..client import transcribe_gen
from pydub import AudioSegment
from queue import PriorityQueue
jasper_map = {
"PNRs": 8045,
"Cities": 8046,
"Names": 8047,
"Dates": 8048,
}
# jasper_map = {"PNRs": 8050, "Cities": 8050, "Names": 8050, "Dates": 8050}
transcriber_gcp = gcp_transcribe_gen()
transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
transcriber_all_trained = transcribe_gen(asr_port=8050)
transcriber_libri_all_trained = transcribe_gen(asr_port=8051)
def find_ent(dd, conv_data):
ents = PriorityQueue()
for ent in conv_data:
if ent in dd["text"]:
ents.put((-len(ent), ent))
return ents.get_nowait()[1]
def process_data(d):
orig_seg = AudioSegment.from_wav(d["audio_path"])
jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
d["audio_path"].stem + ".txt"
)
if deepgram_file.exists():
d["deepgram"] = "".join(
[s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")]
)
else:
d["deepgram"] = 'Not Found'
d["audio_path"] = str(d["audio_path"])
d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
return d
conv_data = ExtendedPath(conv_src).read_json()
conv_data["Dates"] = generate_dates()
dump_data_path = dump_dir / Path(data_name) / dump_file
ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
manifest_path = dump_dir / Path(data_name) / manifest_file
test_points = list(asr_manifest_reader(manifest_path))
test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
test_data = parallel_apply(process_data, test_data_objs)
# test_data = [process_data(t) for t in test_data_objs]
test_path = dump_dir / Path(data_name) / test_file
def dd_gen(dump_data):
for dd in dump_data:
ent = find_ent(dd, conv_data[extraction_key])
dd["entity"] = ent
if ent:
yield dd
asr_test_writer(test_path, dd_gen(test_data))
# for i, b in enumerate(batch(test_data, 1)):
# test_fname = Path(f"{test_file.stem}_{i}.reg")
# test_path = dump_dir / Path(data_name) / test_fname
# asr_test_writer(test_path, dd_gen(test_data))
def main():
app()
if __name__ == "__main__":
main()

View File

@ -7,6 +7,7 @@ from itertools import product
from functools import partial from functools import partial
from math import floor from math import floor
from uuid import uuid4 from uuid import uuid4
from urllib.parse import urlsplit
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import numpy as np import numpy as np
@ -99,6 +100,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"data": [], "data": [],
} }
data_funcs = [] data_funcs = []
transcriber_gcp = gcp_transcribe_gen()
transcriber_pretrained = transcribe_gen(asr_port=8044) transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf: with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}") print(f"writing manifest to {asr_manifest}")
@ -115,6 +117,8 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
rel_pnr_path, rel_pnr_path,
): ):
pretrained_result = transcriber_pretrained(aud_seg.raw_data) pretrained_result = transcriber_pretrained(aud_seg.raw_data)
gcp_seg = aud_seg.set_frame_rate(16000)
gcp_result = transcriber_gcp(gcp_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result]) pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = ( wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png") dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@ -130,6 +134,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"spoken": transcript, "spoken": transcript,
"caller": caller_name, "caller": caller_name,
"utterance_id": fname, "utterance_id": fname,
"gcp_asr": gcp_result,
"pretrained_asr": pretrained_result, "pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer, "pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path), "plot_path": str(wav_plot_path),
@ -194,6 +199,32 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
mf.write(manifest) mf.write(manifest)
def asr_test_writer(out_file_path: Path, source):
def dd_str(dd, idx):
path = dd["audio_filepath"]
# dur = dd["duration"]
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
res_file = out_file_path.with_suffix(".result.json")
with out_file_path.open("w") as of:
print(f"opening {out_file_path} for writing test")
results = []
idx = 0
for ui_dd in source:
results.append(ui_dd)
out_str = dd_str(ui_dd, idx)
of.write(out_str)
idx += 1
of.write("DO_HANGUP\n")
ExtendedPath(res_file).write_json(results)
def batch(iterable, n=1):
ls = len(iterable)
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
class ExtendedPath(type(Path())): class ExtendedPath(type(Path())):
"""docstring for ExtendedPath.""" """docstring for ExtendedPath."""
@ -278,6 +309,181 @@ def generate_dates():
return [dm for d, m in product(days, months) for dm in canon_vars(d, m)] return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
def get_call_logs(call_obj, s3, call_meta_dir):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gcp_transcribe_gen():
from google.cloud import speech_v1
from google.cloud.speech_v1 import enums
# import io
client = speech_v1.SpeechClient()
# local_file_path = 'resources/brooklyn_bridge.raw'
# The language of the supplied audio
language_code = "en-US"
model = "phone_call"
# Sample rate in Hertz of the audio data sent
sample_rate_hertz = 16000
# Encoding of audio data sent. This sample sets this explicitly.
# This field is optional for FLAC and WAV audio formats.
encoding = enums.RecognitionConfig.AudioEncoding.LINEAR16
config = {
"language_code": language_code,
"sample_rate_hertz": sample_rate_hertz,
"encoding": encoding,
"model": model,
"enable_automatic_punctuation": True,
"max_alternatives": 10,
"enable_word_time_offsets": True, # used to detect start and end time of utterances
"speech_contexts": [
{
"phrases": [
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_DIGIT_SEQUENCE",
"$TIME",
"$YEAR",
]
},
{
"phrases": [
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
]
},
{
"phrases": [
"PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
]
},
{"phrases": ["my name is"]},
{"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
{
"phrases": [
"John Smith",
"Carina Hu",
"Travis Lim",
"Marvin Tan",
"Samuel Tan",
"Dawn Mathew",
"Dawn",
"Mathew",
]
},
{
"phrases": [
"Beijing",
"Tokyo",
"London",
"19 August",
"7 October",
"11 December",
"17 September",
"19th August",
"7th October",
"11th December",
"17th September",
"ABC123",
"KWXUNP",
"XLU5K1",
"WL2JV6",
"KBS651",
]
},
{
"phrases": [
"first flight",
"second flight",
"third flight",
"first option",
"second option",
"third option",
"first one",
"second one",
"third one",
]
},
],
"metadata": {
"industry_naics_code_of_audio": 481111,
"interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
},
}
def sample_recognize(content):
"""
Transcribe a short audio file using synchronous speech recognition
Args:
local_file_path Path to local audio file, e.g. /path/audio.wav
"""
# with io.open(local_file_path, "rb") as f:
# content = f.read()
audio = {"content": content}
response = client.recognize(config, audio)
for result in response.results:
# First alternative is the most probable result
return "/".join([alt.transcript for alt in result.alternatives])
# print(u"Transcript: {}".format(alternative.transcript))
return ""
return sample_recognize
def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}")
return [
res
for res in tqdm(
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
)
]
def main(): def main():
for c in random_pnr_generator(): for c in random_pnr_generator():
print(c) print(c)

View File

@ -271,7 +271,7 @@ def split_extract(
dump_file: Path = Path("ui_dump.json"), dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"), manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True), corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = Path("./data/conv_data.json"), conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
extraction_type: ExtractionType = ExtractionType.all, extraction_type: ExtractionType = ExtractionType.all,
): ):
import shutil import shutil
@ -299,10 +299,16 @@ def split_extract(
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen)) asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
ui_data_path = dump_dir / Path(data_name) / dump_file ui_data_path = dump_dir / Path(data_name) / dump_file
ui_data = json.load(ui_data_path.open())["data"] orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data} file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data)) extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
ExtendedPath(dest_ui_path).write_json(extracted_ui_data) final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
final_data.append(d)
orig_ui_data['data'] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file: if corrections_file:
dest_correction_path = dest_data_dir / corrections_file dest_correction_path = dest_data_dir / corrections_file
@ -331,7 +337,7 @@ def update_corrections(
manifest_file: Path = Path("manifest.json"), manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"), corrections_file: Path = Path("corrections.json"),
ui_dump_file: Path = Path("ui_dump.json"), ui_dump_file: Path = Path("ui_dump.json"),
skip_incorrect: bool = True, skip_incorrect: bool = typer.Option(True, show_default=True),
): ):
data_manifest_path = dump_dir / Path(data_name) / manifest_file data_manifest_path = dump_dir / Path(data_name) / manifest_file
corrections_path = dump_dir / Path(data_name) / corrections_file corrections_path = dump_dir / Path(data_name) / corrections_file

View File

@ -39,6 +39,7 @@ extra_requirements = {
"streamlit==0.58.0", "streamlit==0.58.0",
"natural==0.2.0", "natural==0.2.0",
"stringcase==1.2.0", "stringcase==1.2.0",
"google-cloud-speech~=1.3.1",
] ]
# "train": [ # "train": [
# "torchaudio==0.5.0", # "torchaudio==0.5.0",
@ -66,12 +67,14 @@ setup(
"jasper_data_tts_generate = jasper.data.tts_generator:main", "jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main", "jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main", "jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_test_generate = jasper.data.test_generator:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main", "jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main", "jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main", "jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_server = jasper.data.server:main", "jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main", "jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main", "jasper_data_preprocess = jasper.data.process:main",
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
] ]
}, },
zip_safe=False, zip_safe=False,