Compare commits

...

3 Commits

Author SHA1 Message Date
Malar Kannan 069392d098 1. added a test generator and slu evaluator
2. ui dump now include gcp results
3. showing default option for more args validation process commands
2020-06-29 14:24:56 +05:30
Malar Kannan 515e9c1037 1. split extract all data types in one shot with --extraction-type all flag
2. add notes about diffing split extracted and original data
3. add a nlu conv generator to generate conv data based on nlu utterances and entities
4. add task uid support for dumping corrections
5. abstracted generate date fn
2020-06-25 11:03:09 +05:30
Malar Kannan e76ccda5dd 1. fix update-correction to use ui_dump instead of manifest
2. update training params no of checkpoints on chpk frequency
2020-06-19 14:16:04 +05:30
12 changed files with 731 additions and 234 deletions

5
Notes.md Normal file
View File

@ -0,0 +1,5 @@
> Diff after splitting based on type
```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```

View File

@ -147,7 +147,7 @@ def analyze(
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
from pydub import AudioSegment
from natural.date import compress
@ -170,18 +170,6 @@ def analyze(
call_logs = yaml.load(call_logs_file.read_text())
def get_call_meta(call_obj):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gen_ev_fev_timedelta(fev):
fev_p = Timestamp()
fev_p.FromJsonString(fev["CreatedTS"])
@ -283,7 +271,7 @@ def analyze(
return spoken
def process_call(call_obj):
call_meta = get_call_meta(call_obj)
call_meta = get_call_logs(call_obj, s3, call_meta_dir)
call_events = call_meta["Events"]
def is_writer_uri_event(ev):

View File

@ -1,8 +1,6 @@
import typer
from pathlib import Path
from random import randrange
from itertools import product
from math import floor
from .utils import generate_dates
app = typer.Typer()
@ -16,46 +14,7 @@ def export_conv_json(
conv_data = ExtendedPath(conv_src).read_json()
days = [i for i in range(1, 32)]
months = [
"January",
"February",
"March",
"April",
"May",
"June",
"July",
"August",
"September",
"October",
"November",
"December",
]
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
def ordinal(n):
return "%d%s" % (
n,
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
)
def canon_vars(d, m):
return [
ordinal(d) + " " + m,
m + " " + ordinal(d),
ordinal(d) + " of " + m,
m + " the " + ordinal(d),
str(d) + " " + m,
m + " " + str(d),
]
day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
conv_data["dates"] = day_months
def dates_data_gen():
i = randrange(len(day_months))
return day_months[i]
conv_data["dates"] = generate_dates()
ExtendedPath(conv_dest).write_json(conv_data)

View File

@ -0,0 +1,98 @@
from pathlib import Path
import typer
import pandas as pd
from ruamel.yaml import YAML
from itertools import product
from .utils import generate_dates
app = typer.Typer()
def unique_entity_list(entity_template_tags, entity_data):
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return list(unique_entity_set)
def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
yaml = YAML()
nlu_data = yaml.load(nlu_data_file.read_text())
for cf in nlu_data["csv_files"]:
data = pd.read_csv(cf["fname"])
for et in cf["entities"]:
entity_name = et["name"]
entity_template_tags = et["tags"]
if "filter" in et:
entity_data = data[data[cf["filter_key"]] == et["filter"]]
else:
entity_data = data
yield entity_name, entity_template_tags, entity_data
def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
yaml = YAML()
nlu_data = yaml.load(nlu_data_file.read_text())
sm = {s["name"]: s for s in nlu_data["samples_per_entity"]}
return sm
@app.command()
def compute_unique_nlu_stats(
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
nlu_data_file
):
entity_count = len(unique_entity_list(entity_template_tags, entity_data))
print(f"{entity_name}\t{entity_count}")
def replace_entity(tmpl, value, tags):
result = tmpl
for t in tags:
result = result.replace(t, value)
return result
@app.command()
def export_nlu_conv_json(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
from .utils import ExtendedPath
from random import sample
entity_samples = nlu_samples_reader(nlu_data_file)
conv_data = ExtendedPath(conv_src).read_json()
conv_data["Dates"] = generate_dates()
result_dict = {}
data_count = 0
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
nlu_data_file
):
entity_variants = sample(conv_data[entity_name], entity_samples[entity_name]["test_size"])
unique_entites = unique_entity_list(entity_template_tags, entity_data)
# sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
result_dict[entity_name] = []
for val in entity_variants:
sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
for tmpl in sample_entites:
result = replace_entity(tmpl, val, entity_template_tags)
result_dict[entity_name].append(result)
data_count += 1
print(f"Total of {data_count} variants generated")
ExtendedPath(conv_dest).write_json(result_dict)
def main():
app()
if __name__ == "__main__":
main()

View File

@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
@app.command()
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
reader_list = []
abs_manifest_path = Path("abs_manifest.json")
for dataset_path in src_dataset_paths:

View File

@ -0,0 +1,180 @@
import typer
from pathlib import Path
import json
# from .utils import generate_dates, asr_test_writer
app = typer.Typer()
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
from time import sleep
import subprocess
from .utils import ExtendedPath, get_call_logs
coll.delete_many({"CallID": test_path.name})
# test_path = dump_dir / data_name / test_file
# "../saas_reg/regression/run.sh -f data/asr_data/call_upwork_test_cnd_cities/asr_test.reg"
test_output = subprocess.run(
["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
)
if test_output.returncode != 0:
print("Error running test {test_file}")
return
def get_meta():
call_meta = coll.find_one({"CallID": test_path.name})
if call_meta:
return call_meta
else:
sleep(2)
return get_meta()
call_meta = get_meta()
call_logs = get_call_logs(call_meta, s3, call_meta_dir)
call_events = call_logs["Events"]
test_data_path = test_path.with_suffix(".result.json")
test_data = ExtendedPath(test_data_path).read_json()
def is_final_asr_event_or_spoken(ev):
pld = json.loads(ev["Payload"])
return (
pld["AsrResult"]["Results"][0]["IsFinal"]
if ev["Type"] == "ASR_RESULT"
else False
)
def is_test_event(ev):
return (
ev["Author"] == "NLU"
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
) and (ev["Type"] != "DEBUG")
test_evs = list(filter(is_test_event, call_events))
if len(test_evs) == 2:
try:
asr_payload = test_evs[0]["Payload"]
asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0]
alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]]
gcp_result = "|".join(alt_tscripts)
entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
nlu_payload = test_evs[1]["Payload"]
nlu_result_payload = json.loads(nlu_payload)["NluResults"]
entity = test_data[0]["entity"]
text = test_data[0]["text"]
audio_filepath = test_data[0]["audio_filepath"]
pretrained_asr = test_data[0]["pretrained_asr"]
nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
asr_entity = city_code[entity] if entity in city_code else "UNKNOWN"
entities_match = asr_entity == nlu_entity
result = "Success" if entities_match else "Fail"
return {
"expected_entity": entity,
"text": text,
"audio_filepath": audio_filepath,
"pretrained_asr": pretrained_asr,
"entity_asr": entity_asr,
"google_asr": gcp_result,
"nlu_result": nlu_result_payload,
"asr_entity": asr_entity,
"nlu_entity": nlu_entity,
"result": result,
}
except Exception:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Error",
}
else:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Empty",
}
@app.command()
def evaluate_slu(
# conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
# extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
call_meta_dir: Path = Path("./data/call_metas"),
test_file_pref: str = "asr_test",
mongo_uri: str = typer.Option(
"mongodb://localhost:27017/test.calls", show_default=True
),
test_results: Path = Path("./data/results.csv"),
airport_codes: Path = Path("./airports_code.csv"),
reg_path: Path = Path("../saas_reg/regression/run.sh"),
test_id: str = "5ef481f27031edf6910e94e0",
):
# import json
from .utils import get_mongo_coll
import pandas as pd
import boto3
from concurrent.futures import ThreadPoolExecutor
from functools import partial
# import subprocess
# from time import sleep
import csv
from tqdm import tqdm
s3 = boto3.client("s3")
df = pd.read_csv(airport_codes)[["iata", "city"]]
city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()
test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
coll = get_mongo_coll(mongo_uri)
with test_results.open("w") as csvfile:
fieldnames = [
"expected_entity",
"text",
"audio_filepath",
"pretrained_asr",
"entity_asr",
"google_asr",
"nlu_result",
"asr_entity",
"nlu_entity",
"result",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
with ThreadPoolExecutor(max_workers=8) as exe:
print("starting all loading tasks")
for test_result in tqdm(
exe.map(
partial(run_test, reg_path, coll, s3, call_meta_dir, city_code),
test_files,
),
position=0,
leave=True,
total=len(test_files),
):
writer.writerow(test_result)
def main():
app()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,99 @@
import typer
from pathlib import Path
from .utils import generate_dates, asr_test_writer
app = typer.Typer()
@app.command()
def export_test_reg(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
test_file: Path = Path("asr_test.reg"),
):
from .utils import (
ExtendedPath,
asr_manifest_reader,
gcp_transcribe_gen,
parallel_apply,
)
from ..client import transcribe_gen
from pydub import AudioSegment
from queue import PriorityQueue
jasper_map = {
"PNRs": 8045,
"Cities": 8046,
"Names": 8047,
"Dates": 8048,
}
# jasper_map = {"PNRs": 8050, "Cities": 8050, "Names": 8050, "Dates": 8050}
transcriber_gcp = gcp_transcribe_gen()
transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
transcriber_all_trained = transcribe_gen(asr_port=8050)
transcriber_libri_all_trained = transcribe_gen(asr_port=8051)
def find_ent(dd, conv_data):
ents = PriorityQueue()
for ent in conv_data:
if ent in dd["text"]:
ents.put((-len(ent), ent))
return ents.get_nowait()[1]
def process_data(d):
orig_seg = AudioSegment.from_wav(d["audio_path"])
jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
d["audio_path"].stem + ".txt"
)
if deepgram_file.exists():
d["deepgram"] = "".join(
[s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")]
)
else:
d["deepgram"] = 'Not Found'
d["audio_path"] = str(d["audio_path"])
d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
return d
conv_data = ExtendedPath(conv_src).read_json()
conv_data["Dates"] = generate_dates()
dump_data_path = dump_dir / Path(data_name) / dump_file
ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
manifest_path = dump_dir / Path(data_name) / manifest_file
test_points = list(asr_manifest_reader(manifest_path))
test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
test_data = parallel_apply(process_data, test_data_objs)
# test_data = [process_data(t) for t in test_data_objs]
test_path = dump_dir / Path(data_name) / test_file
def dd_gen(dump_data):
for dd in dump_data:
ent = find_ent(dd, conv_data[extraction_key])
dd["entity"] = ent
if ent:
yield dd
asr_test_writer(test_path, dd_gen(test_data))
# for i, b in enumerate(batch(test_data, 1)):
# test_fname = Path(f"{test_file.stem}_{i}.reg")
# test_path = dump_dir / Path(data_name) / test_fname
# asr_test_writer(test_path, dd_gen(test_data))
def main():
app()
if __name__ == "__main__":
main()

View File

@ -1,92 +0,0 @@
import pandas as pd
def compute_pnr_name_city():
data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")
def unique_pnr_count():
pnr_data = data[data["Input.Answer"] == "ZZZZZZ"]
unique_pnr_set = {
t
for n in range(1, 5)
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
if "ZZZZZZ" in t
}
return len(unique_pnr_set)
def unique_name_count():
pnr_data = data[data["Input.Answer"] == "John Doe"]
unique_pnr_set = {
t
for n in range(1, 5)
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
if "John Doe" in t
}
return len(unique_pnr_set)
def unique_city_count():
pnr_data = data[data["Input.Answer"] == "Heathrow Airport"]
unique_pnr_set = {
t
for n in range(1, 5)
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
if "Heathrow Airport" in t
}
return len(unique_pnr_set)
def unique_entity_count(entity_template_tags):
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
entity_data = data
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return len(unique_entity_set)
print('PNR', unique_pnr_count())
print('Name', unique_name_count())
print('City', unique_city_count())
print('Payment', unique_entity_count(['KPay', 'ZPay', 'Credit Card']))
def compute_date():
entity_template_tags = ['27 january', 'December 18']
data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
# data.sample(10)
def unique_entity_count(entity_template_tags):
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
entity_data = data
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return len(unique_entity_set)
print('Date', unique_entity_count(entity_template_tags))
def compute_option():
entity_template_tag = 'third'
data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")
def unique_entity_count():
entity_data = data[data['Input.Prompt'] == entity_template_tag]
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if entity_template_tag in t
}
return len(unique_entity_set)
print('Option', unique_entity_count())
compute_pnr_name_city()
compute_date()
compute_option()

View File

@ -1,13 +1,18 @@
import numpy as np
import wave
import io
import os
import json
import wave
from pathlib import Path
from itertools import product
from functools import partial
from math import floor
from uuid import uuid4
from urllib.parse import urlsplit
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pymongo
from slugify import slugify
from uuid import uuid4
from num2words import num2words
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate
@ -15,8 +20,6 @@ import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
from functools import partial
from concurrent.futures import ThreadPoolExecutor
def manifest_str(path, dur, text):
@ -59,6 +62,10 @@ def alnum_to_asr_tokens(text):
return ("".join(num_tokens)).lower()
def tscript_uuid_fname(transcript):
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
@ -67,7 +74,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
for transcript, audio_dur, wav_data in asr_data_source:
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
fname = tscript_uuid_fname(transcript)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
rel_pnr_path = audio_file.relative_to(dataset_dir)
@ -93,6 +100,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"data": [],
}
data_funcs = []
transcriber_gcp = gcp_transcribe_gen()
transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
@ -109,6 +117,8 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
rel_pnr_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
gcp_seg = aud_seg.set_frame_rate(16000)
gcp_result = transcriber_gcp(gcp_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@ -124,6 +134,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"spoken": transcript,
"caller": caller_name,
"utterance_id": fname,
"gcp_asr": gcp_result,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
@ -174,7 +185,7 @@ def asr_manifest_reader(data_manifest_path: Path):
pnr_data = [json.loads(v) for v in pnr_jsonl]
for p in pnr_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["chars"] = Path(p["audio_filepath"]).stem
p["text"] = p["text"].strip()
yield p
@ -188,6 +199,32 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
mf.write(manifest)
def asr_test_writer(out_file_path: Path, source):
def dd_str(dd, idx):
path = dd["audio_filepath"]
# dur = dd["duration"]
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
res_file = out_file_path.with_suffix(".result.json")
with out_file_path.open("w") as of:
print(f"opening {out_file_path} for writing test")
results = []
idx = 0
for ui_dd in source:
results.append(ui_dd)
out_str = dd_str(ui_dd, idx)
of.write(out_str)
idx += 1
of.write("DO_HANGUP\n")
ExtendedPath(res_file).write_json(results)
def batch(iterable, n=1):
ls = len(iterable)
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
class ExtendedPath(type(Path())):
"""docstring for ExtendedPath."""
@ -234,6 +271,219 @@ def plot_seg(wav_plot_path, audio_path):
fig.savefig(wav_plot_f, format="png", dpi=50)
def generate_dates():
days = [i for i in range(1, 32)]
months = [
"January",
"February",
"March",
"April",
"May",
"June",
"July",
"August",
"September",
"October",
"November",
"December",
]
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
def ordinal(n):
return "%d%s" % (
n,
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
)
def canon_vars(d, m):
return [
ordinal(d) + " " + m,
m + " " + ordinal(d),
ordinal(d) + " of " + m,
m + " the " + ordinal(d),
str(d) + " " + m,
m + " " + str(d),
]
return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
def get_call_logs(call_obj, s3, call_meta_dir):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gcp_transcribe_gen():
from google.cloud import speech_v1
from google.cloud.speech_v1 import enums
# import io
client = speech_v1.SpeechClient()
# local_file_path = 'resources/brooklyn_bridge.raw'
# The language of the supplied audio
language_code = "en-US"
model = "phone_call"
# Sample rate in Hertz of the audio data sent
sample_rate_hertz = 16000
# Encoding of audio data sent. This sample sets this explicitly.
# This field is optional for FLAC and WAV audio formats.
encoding = enums.RecognitionConfig.AudioEncoding.LINEAR16
config = {
"language_code": language_code,
"sample_rate_hertz": sample_rate_hertz,
"encoding": encoding,
"model": model,
"enable_automatic_punctuation": True,
"max_alternatives": 10,
"enable_word_time_offsets": True, # used to detect start and end time of utterances
"speech_contexts": [
{
"phrases": [
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_DIGIT_SEQUENCE",
"$TIME",
"$YEAR",
]
},
{
"phrases": [
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
]
},
{
"phrases": [
"PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
]
},
{"phrases": ["my name is"]},
{"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
{
"phrases": [
"John Smith",
"Carina Hu",
"Travis Lim",
"Marvin Tan",
"Samuel Tan",
"Dawn Mathew",
"Dawn",
"Mathew",
]
},
{
"phrases": [
"Beijing",
"Tokyo",
"London",
"19 August",
"7 October",
"11 December",
"17 September",
"19th August",
"7th October",
"11th December",
"17th September",
"ABC123",
"KWXUNP",
"XLU5K1",
"WL2JV6",
"KBS651",
]
},
{
"phrases": [
"first flight",
"second flight",
"third flight",
"first option",
"second option",
"third option",
"first one",
"second one",
"third one",
]
},
],
"metadata": {
"industry_naics_code_of_audio": 481111,
"interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
},
}
def sample_recognize(content):
"""
Transcribe a short audio file using synchronous speech recognition
Args:
local_file_path Path to local audio file, e.g. /path/audio.wav
"""
# with io.open(local_file_path, "rb") as f:
# content = f.read()
audio = {"content": content}
response = client.recognize(config, audio)
for result in response.results:
# First alternative is the most probable result
return "/".join([alt.transcript for alt in result.alternatives])
# print(u"Transcript: {}".format(alternative.transcript))
return ""
return sample_recognize
def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}")
return [
res
for res in tqdm(
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
)
]
def main():
for c in random_pnr_generator():
print(c)

View File

@ -11,6 +11,7 @@ from ..utils import (
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
@ -180,14 +181,16 @@ def task_ui(
@app.command()
def dump_corrections(
task_uid: str,
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("corrections.json"),
):
dump_path = dump_dir / Path(data_name) / dump_fname
col = get_mongo_conn(col="asr_validation")
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@ -256,49 +259,36 @@ class ExtractionType(str, Enum):
date = "dates"
city = "cities"
name = "names"
all = "all"
@app.command()
def split_extract(
data_name: str = typer.Option("call_alphanum", show_default=True),
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
# dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
conv_data_path: Path = Path("./data/conv_data.json"),
extraction_type: ExtractionType = ExtractionType.date,
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
extraction_type: ExtractionType = ExtractionType.all,
):
import shutil
def get_conv_data(cdp):
from itertools import product
data_manifest_path = dump_dir / Path(data_name) / manifest_file
conv_data = ExtendedPath(conv_data_path).read_json()
conv_data = json.load(cdp.open())
days = [str(i) for i in range(1, 32)]
months = conv_data["months"]
day_months = {d + " " + m for d, m in product(days, months)}
return {
"cities": set(conv_data["cities"]),
"names": set(conv_data["names"]),
"dates": day_months,
}
dest_data_name = data_name + "_" + extraction_type.value
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
conv_data = get_conv_data(conv_data_path)
extraction_vals = conv_data[extraction_type.value]
def extract_data_of_type(extraction_key):
extraction_vals = conv_data[extraction_key]
dest_data_name = data_name + "_" + extraction_key.lower()
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = manifest_dir / Path(dest_data_name)
dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_dir = dump_dir / Path(dest_data_name)
dest_ui_dir.mkdir(exist_ok=True, parents=True)
dest_ui_path = dest_ui_dir / dump_file
dest_correction_path = dest_ui_dir / corrections_file
dest_ui_path = dest_data_dir / dump_file
def extract_manifest(mg):
for m in mg:
@ -309,14 +299,21 @@ def split_extract(
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
ui_data_path = dump_dir / Path(data_name) / dump_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_data = json.load(ui_data_path.open())["data"]
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
corrections = json.load(corrections_path.open())
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
final_data.append(d)
orig_ui_data['data'] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
if corrections_file:
dest_correction_path = dest_data_dir / corrections_file
corrections_path = dump_dir / Path(data_name) / corrections_file
corrections = json.load(corrections_path.open())
extracted_corrections = list(
filter(
lambda c: c["code"] in file_ui_map
@ -326,23 +323,29 @@ def split_extract(
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == 'all':
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
extract_data_of_type(extraction_type.value)
@app.command()
def update_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/valiation_data"),
manifest_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
skip_incorrect: bool = True,
ui_dump_file: Path = Path("ui_dump.json"),
skip_incorrect: bool = typer.Option(True, show_default=True),
):
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
data_manifest_path = dump_dir / Path(data_name) / manifest_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
def correct_manifest(manifest_data_gen, corrections_path):
corrections = json.load(corrections_path.open())
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@ -355,36 +358,38 @@ def update_corrections(
# for d in manifest_data_gen:
# if d["chars"] in incorrect_set:
# d["audio_path"].unlink()
renamed_set = set()
for d in manifest_data_gen:
if d["chars"] in correct_set:
# renamed_set = set()
for d in ui_data:
if d["utterance_id"] in correct_set:
yield {
"audio_filepath": d["audio_filepath"],
"duration": d["duration"],
"text": d["text"],
}
elif d["chars"] in correction_map:
correct_text = correction_map[d["chars"]]
elif d["utterance_id"] in correction_map:
correct_text = correction_map[d["utterance_id"]]
if skip_incorrect:
print(
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
)
else:
renamed_set.add(correct_text)
new_name = str(Path(correct_text).with_suffix(".wav"))
d["audio_path"].replace(d["audio_path"].with_name(new_name))
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
yield {
"audio_filepath": new_filepath,
"duration": d["duration"],
"text": alnum_to_asr_tokens(correct_text),
"text": correct_text,
}
else:
orig_audio_path = Path(d["audio_path"])
# don't delete if another correction points to an old file
if d["chars"] not in renamed_set:
d["audio_path"].unlink()
else:
print(f'skipping deletion of correction:{d["chars"]}')
# if d["text"] not in renamed_set:
orig_audio_path.unlink()
# else:
# print(f'skipping deletion of correction:{d["text"]}')
typer.echo(f"Using data manifest:{data_manifest_path}")
dataset_dir = data_manifest_path.parent
@ -393,8 +398,8 @@ def update_corrections(
if not backup_dir.exists():
typer.echo(f"backing up to :{backup_dir}")
shutil.copytree(str(dataset_dir), str(backup_dir))
manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
# manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
new_data_manifest_path.replace(data_manifest_path)

View File

@ -41,7 +41,7 @@ def parse_args():
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=200,
checkpoint_save_freq=100,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
folder=neural_factory.checkpoint_dir,
load_from_folder=args.load_dir,
step_freq=args.checkpoint_save_freq,
checkpoints_to_keep=30,
)
callbacks = [train_callback, chpt_callback]

View File

@ -39,6 +39,7 @@ extra_requirements = {
"streamlit==0.58.0",
"natural==0.2.0",
"stringcase==1.2.0",
"google-cloud-speech~=1.3.1",
]
# "train": [
# "torchaudio==0.5.0",
@ -65,12 +66,15 @@ setup(
"jasper_trainer = jasper.training.cli:main",
"jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_test_generate = jasper.data.test_generator:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main",
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
]
},
zip_safe=False,