Compare commits


No commits in common. "069392d09893106bc8b66a3a436af60224fbda4a" and "000853b600975fa4b1050f0bec875b7db2a89c34" have entirely different histories.

12 changed files with 234 additions and 731 deletions

View File

@@ -1,5 +0,0 @@
> Diff after splitting based on type
```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```
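A minimal Python sketch of the same check, for anyone who prefers scripting it: it assumes the manifests are plain JSON-lines files and that a correct split means the sorted, concatenated per-type manifests equal the sorted original. Paths mirror the command above.
```
# Sketch only: compare the concatenated, sorted per-type manifests against the
# original manifest, mirroring the shell one-liner above.
from pathlib import Path

def manifests_match(split_glob: str, original: str) -> bool:
    split_lines = sorted(
        line
        for p in Path(".").glob(split_glob)
        for line in p.read_text().splitlines()
    )
    original_lines = sorted(Path(original).read_text().splitlines())
    return split_lines == original_lines

# manifests_match("data/asr_data/call_upwork_test_cnd_*/manifest.json",
#                 "data/asr_data/call_upwork_test_cnd/manifest.json")
```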

View File

@@ -147,7 +147,7 @@ def analyze(
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll
from pydub import AudioSegment
from natural.date import compress
@@ -170,6 +170,18 @@ def analyze(
call_logs = yaml.load(call_logs_file.read_text())
def get_call_meta(call_obj):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gen_ev_fev_timedelta(fev):
fev_p = Timestamp()
fev_p.FromJsonString(fev["CreatedTS"])
@@ -271,7 +283,7 @@ def analyze(
return spoken
def process_call(call_obj):
call_meta = get_call_logs(call_obj, s3, call_meta_dir)
call_meta = get_call_meta(call_obj)
call_events = call_meta["Events"]
def is_writer_uri_event(ev):
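The inlined get_call_meta above caches each call's metadata JSON locally and only hits S3 on a miss; the bucket and key come straight out of urlsplit. A small illustrative sketch (the URI below is made up):
```
# Illustration of the URI handling in get_call_meta: bucket = netloc, key = path
# without the leading slash. The URI is an assumed example.
from urllib.parse import urlsplit

parts = urlsplit("s3://example-bucket/call-metas/call-123.json")
bucket, key = parts.netloc, parts.path[1:]
assert (bucket, key) == ("example-bucket", "call-metas/call-123.json")
# s3.download_file(bucket, key, str(saved_meta_path)) then fetches the JSON.
```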

View File

@@ -1,6 +1,8 @@
import typer
from pathlib import Path
from .utils import generate_dates
from random import randrange
from itertools import product
from math import floor
app = typer.Typer()
@@ -14,7 +16,46 @@ def export_conv_json(
conv_data = ExtendedPath(conv_src).read_json()
conv_data["dates"] = generate_dates()
days = [i for i in range(1, 32)]
months = [
"January",
"February",
"March",
"April",
"May",
"June",
"July",
"August",
"September",
"October",
"November",
"December",
]
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
def ordinal(n):
return "%d%s" % (
n,
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
)
def canon_vars(d, m):
return [
ordinal(d) + " " + m,
m + " " + ordinal(d),
ordinal(d) + " of " + m,
m + " the " + ordinal(d),
str(d) + " " + m,
m + " " + str(d),
]
day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
conv_data["dates"] = day_months
def dates_data_gen():
i = randrange(len(day_months))
return day_months[i]
ExtendedPath(conv_dest).write_json(conv_data)
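The ordinal() helper above relies on a compact suffix-table trick (credited in the code to the linked Stack Overflow answer). A short self-contained walkthrough with the expected outputs:
```
# Walkthrough of the ordinal() suffix trick used above (demonstration only):
# the slice picks "st"/"nd"/"rd" out of the packed string "tsnrhtdd" and falls
# back to "th" for 0, 4-9 and the 11-13 "teens".
from math import floor

def ordinal(n):
    return "%d%s" % (
        n,
        "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
    )

assert [ordinal(n) for n in (1, 2, 3, 4, 11, 12, 13, 21, 22, 23)] == [
    "1st", "2nd", "3rd", "4th", "11th", "12th", "13th", "21st", "22nd", "23rd",
]
```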

View File

@@ -1,98 +0,0 @@
from pathlib import Path
import typer
import pandas as pd
from ruamel.yaml import YAML
from itertools import product
from .utils import generate_dates
app = typer.Typer()
def unique_entity_list(entity_template_tags, entity_data):
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return list(unique_entity_set)
def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
yaml = YAML()
nlu_data = yaml.load(nlu_data_file.read_text())
for cf in nlu_data["csv_files"]:
data = pd.read_csv(cf["fname"])
for et in cf["entities"]:
entity_name = et["name"]
entity_template_tags = et["tags"]
if "filter" in et:
entity_data = data[data[cf["filter_key"]] == et["filter"]]
else:
entity_data = data
yield entity_name, entity_template_tags, entity_data
def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
yaml = YAML()
nlu_data = yaml.load(nlu_data_file.read_text())
sm = {s["name"]: s for s in nlu_data["samples_per_entity"]}
return sm
@app.command()
def compute_unique_nlu_stats(
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
nlu_data_file
):
entity_count = len(unique_entity_list(entity_template_tags, entity_data))
print(f"{entity_name}\t{entity_count}")
def replace_entity(tmpl, value, tags):
result = tmpl
for t in tags:
result = result.replace(t, value)
return result
@app.command()
def export_nlu_conv_json(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
from .utils import ExtendedPath
from random import sample
entity_samples = nlu_samples_reader(nlu_data_file)
conv_data = ExtendedPath(conv_src).read_json()
conv_data["Dates"] = generate_dates()
result_dict = {}
data_count = 0
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
nlu_data_file
):
entity_variants = sample(conv_data[entity_name], entity_samples[entity_name]["test_size"])
unique_entites = unique_entity_list(entity_template_tags, entity_data)
# sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
result_dict[entity_name] = []
for val in entity_variants:
sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
for tmpl in sample_entites:
result = replace_entity(tmpl, val, entity_template_tags)
result_dict[entity_name].append(result)
data_count += 1
print(f"Total of {data_count} variants generated")
ExtendedPath(conv_dest).write_json(result_dict)
def main():
app()
if __name__ == "__main__":
main()
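Both readers above are driven by nlu_data.yaml, which is not part of this diff. The sketch below is an assumed reconstruction of its shape, inferred only from the keys the readers access; every file name, tag and number is illustrative.
```
# Assumed structure of nlu_data.yaml, expressed as the equivalent Python dict.
# Only the key names are grounded in the reader code above.
assumed_nlu_data = {
    "csv_files": [
        {
            "fname": "customer_provide_answer.csv",  # loaded with pd.read_csv
            "filter_key": "Input.Answer",             # column matched against "filter"
            "entities": [
                {"name": "Cities", "tags": ["Heathrow Airport"], "filter": "Heathrow Airport"},
                {"name": "Names", "tags": ["John Doe"]},  # no "filter": use all rows
            ],
        }
    ],
    "samples_per_entity": [
        {"name": "Cities", "test_size": 20, "samples": 5},
        {"name": "Names", "test_size": 20, "samples": 5},
    ],
}
```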

View File

@@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
@app.command()
def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
reader_list = []
abs_manifest_path = Path("abs_manifest.json")
for dataset_path in src_dataset_paths:

View File

@@ -1,180 +0,0 @@
import typer
from pathlib import Path
import json
# from .utils import generate_dates, asr_test_writer
app = typer.Typer()
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
from time import sleep
import subprocess
from .utils import ExtendedPath, get_call_logs
coll.delete_many({"CallID": test_path.name})
# test_path = dump_dir / data_name / test_file
# "../saas_reg/regression/run.sh -f data/asr_data/call_upwork_test_cnd_cities/asr_test.reg"
test_output = subprocess.run(
["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
)
if test_output.returncode != 0:
print("Error running test {test_file}")
return
def get_meta():
call_meta = coll.find_one({"CallID": test_path.name})
if call_meta:
return call_meta
else:
sleep(2)
return get_meta()
call_meta = get_meta()
call_logs = get_call_logs(call_meta, s3, call_meta_dir)
call_events = call_logs["Events"]
test_data_path = test_path.with_suffix(".result.json")
test_data = ExtendedPath(test_data_path).read_json()
def is_final_asr_event_or_spoken(ev):
pld = json.loads(ev["Payload"])
return (
pld["AsrResult"]["Results"][0]["IsFinal"]
if ev["Type"] == "ASR_RESULT"
else False
)
def is_test_event(ev):
return (
ev["Author"] == "NLU"
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
) and (ev["Type"] != "DEBUG")
test_evs = list(filter(is_test_event, call_events))
if len(test_evs) == 2:
try:
asr_payload = test_evs[0]["Payload"]
asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0]
alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]]
gcp_result = "|".join(alt_tscripts)
entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
nlu_payload = test_evs[1]["Payload"]
nlu_result_payload = json.loads(nlu_payload)["NluResults"]
entity = test_data[0]["entity"]
text = test_data[0]["text"]
audio_filepath = test_data[0]["audio_filepath"]
pretrained_asr = test_data[0]["pretrained_asr"]
nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
asr_entity = city_code[entity] if entity in city_code else "UNKNOWN"
entities_match = asr_entity == nlu_entity
result = "Success" if entities_match else "Fail"
return {
"expected_entity": entity,
"text": text,
"audio_filepath": audio_filepath,
"pretrained_asr": pretrained_asr,
"entity_asr": entity_asr,
"google_asr": gcp_result,
"nlu_result": nlu_result_payload,
"asr_entity": asr_entity,
"nlu_entity": nlu_entity,
"result": result,
}
except Exception:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Error",
}
else:
return {
"expected_entity": test_data[0]["entity"],
"text": test_data[0]["text"],
"audio_filepath": test_data[0]["audio_filepath"],
"pretrained_asr": test_data[0]["pretrained_asr"],
"entity_asr": "",
"google_asr": "",
"nlu_result": "",
"asr_entity": "",
"nlu_entity": "",
"result": "Empty",
}
@app.command()
def evaluate_slu(
# conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
# extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
call_meta_dir: Path = Path("./data/call_metas"),
test_file_pref: str = "asr_test",
mongo_uri: str = typer.Option(
"mongodb://localhost:27017/test.calls", show_default=True
),
test_results: Path = Path("./data/results.csv"),
airport_codes: Path = Path("./airports_code.csv"),
reg_path: Path = Path("../saas_reg/regression/run.sh"),
test_id: str = "5ef481f27031edf6910e94e0",
):
# import json
from .utils import get_mongo_coll
import pandas as pd
import boto3
from concurrent.futures import ThreadPoolExecutor
from functools import partial
# import subprocess
# from time import sleep
import csv
from tqdm import tqdm
s3 = boto3.client("s3")
df = pd.read_csv(airport_codes)[["iata", "city"]]
city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()
test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
coll = get_mongo_coll(mongo_uri)
with test_results.open("w") as csvfile:
fieldnames = [
"expected_entity",
"text",
"audio_filepath",
"pretrained_asr",
"entity_asr",
"google_asr",
"nlu_result",
"asr_entity",
"nlu_entity",
"result",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
with ThreadPoolExecutor(max_workers=8) as exe:
print("starting all loading tasks")
for test_result in tqdm(
exe.map(
partial(run_test, reg_path, coll, s3, call_meta_dir, city_code),
test_files,
),
position=0,
leave=True,
total=len(test_files),
):
writer.writerow(test_result)
def main():
app()
if __name__ == "__main__":
main()
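One idiom in evaluate_slu worth spelling out is the city-to-IATA lookup built from the airports CSV via a pandas Series. A toy example of what it produces (rows are made up):
```
# Toy illustration of the city -> IATA mapping built in evaluate_slu above.
import pandas as pd

df = pd.DataFrame({"iata": ["LHR", "NRT"], "city": ["London", "Tokyo"]})
city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()
assert city_code == {"London": "LHR", "Tokyo": "NRT"}
```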

View File

@@ -1,99 +0,0 @@
import typer
from pathlib import Path
from .utils import generate_dates, asr_test_writer
app = typer.Typer()
@app.command()
def export_test_reg(
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
extraction_key: str = "Cities",
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
test_file: Path = Path("asr_test.reg"),
):
from .utils import (
ExtendedPath,
asr_manifest_reader,
gcp_transcribe_gen,
parallel_apply,
)
from ..client import transcribe_gen
from pydub import AudioSegment
from queue import PriorityQueue
jasper_map = {
"PNRs": 8045,
"Cities": 8046,
"Names": 8047,
"Dates": 8048,
}
# jasper_map = {"PNRs": 8050, "Cities": 8050, "Names": 8050, "Dates": 8050}
transcriber_gcp = gcp_transcribe_gen()
transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
transcriber_all_trained = transcribe_gen(asr_port=8050)
transcriber_libri_all_trained = transcribe_gen(asr_port=8051)
def find_ent(dd, conv_data):
ents = PriorityQueue()
for ent in conv_data:
if ent in dd["text"]:
ents.put((-len(ent), ent))
return ents.get_nowait()[1]
def process_data(d):
orig_seg = AudioSegment.from_wav(d["audio_path"])
jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
d["audio_path"].stem + ".txt"
)
if deepgram_file.exists():
d["deepgram"] = "".join(
[s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")]
)
else:
d["deepgram"] = 'Not Found'
d["audio_path"] = str(d["audio_path"])
d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
return d
conv_data = ExtendedPath(conv_src).read_json()
conv_data["Dates"] = generate_dates()
dump_data_path = dump_dir / Path(data_name) / dump_file
ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
manifest_path = dump_dir / Path(data_name) / manifest_file
test_points = list(asr_manifest_reader(manifest_path))
test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
test_data = parallel_apply(process_data, test_data_objs)
# test_data = [process_data(t) for t in test_data_objs]
test_path = dump_dir / Path(data_name) / test_file
def dd_gen(dump_data):
for dd in dump_data:
ent = find_ent(dd, conv_data[extraction_key])
dd["entity"] = ent
if ent:
yield dd
asr_test_writer(test_path, dd_gen(test_data))
# for i, b in enumerate(batch(test_data, 1)):
# test_fname = Path(f"{test_file.stem}_{i}.reg")
# test_path = dump_dir / Path(data_name) / test_fname
# asr_test_writer(test_path, dd_gen(test_data))
def main():
app()
if __name__ == "__main__":
main()
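find_ent above keys its PriorityQueue on negative length so that, when several known entities occur in a transcript, the longest match wins. A brief demonstration with made-up values:
```
# Demonstration of the longest-match behaviour of find_ent (sample data only).
from queue import PriorityQueue

def longest_entity(text, entities):
    ents = PriorityQueue()
    for ent in entities:
        if ent in text:
            ents.put((-len(ent), ent))
    return ents.get_nowait()[1]

assert longest_entity(
    "flying to new york city tomorrow",
    ["york", "new york", "new york city"],
) == "new york city"
```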

jasper/data/unique_nlu.py (new file, 92 lines)
View File

@@ -0,0 +1,92 @@
import pandas as pd
def compute_pnr_name_city():
data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")
def unique_pnr_count():
pnr_data = data[data["Input.Answer"] == "ZZZZZZ"]
unique_pnr_set = {
t
for n in range(1, 5)
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
if "ZZZZZZ" in t
}
return len(unique_pnr_set)
def unique_name_count():
pnr_data = data[data["Input.Answer"] == "John Doe"]
unique_pnr_set = {
t
for n in range(1, 5)
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
if "John Doe" in t
}
return len(unique_pnr_set)
def unique_city_count():
pnr_data = data[data["Input.Answer"] == "Heathrow Airport"]
unique_pnr_set = {
t
for n in range(1, 5)
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
if "Heathrow Airport" in t
}
return len(unique_pnr_set)
def unique_entity_count(entity_template_tags):
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
entity_data = data
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return len(unique_entity_set)
print('PNR', unique_pnr_count())
print('Name', unique_name_count())
print('City', unique_city_count())
print('Payment', unique_entity_count(['KPay', 'ZPay', 'Credit Card']))
def compute_date():
entity_template_tags = ['27 january', 'December 18']
data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
# data.sample(10)
def unique_entity_count(entity_template_tags):
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
entity_data = data
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if any(et in t for et in entity_template_tags)
}
return len(unique_entity_set)
print('Date', unique_entity_count(entity_template_tags))
def compute_option():
entity_template_tag = 'third'
data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")
def unique_entity_count():
entity_data = data[data['Input.Prompt'] == entity_template_tag]
unique_entity_set = {
t
for n in range(1, 5)
for t in entity_data[f"Answer.utterance-{n}"].tolist()
if entity_template_tag in t
}
return len(unique_entity_set)
print('Option', unique_entity_count())
compute_pnr_name_city()
compute_date()
compute_option()

View File

@@ -1,18 +1,13 @@
import numpy as np
import wave
import io
import os
import json
import wave
from pathlib import Path
from itertools import product
from functools import partial
from math import floor
from uuid import uuid4
from urllib.parse import urlsplit
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pymongo
from slugify import slugify
from uuid import uuid4
from num2words import num2words
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate
@@ -20,6 +15,8 @@ import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
from functools import partial
from concurrent.futures import ThreadPoolExecutor
def manifest_str(path, dur, text):
@@ -62,10 +59,6 @@ def alnum_to_asr_tokens(text):
return ("".join(num_tokens)).lower()
def tscript_uuid_fname(transcript):
return str(uuid4()) + "_" + slugify(transcript, max_length=8)
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
dataset_dir = output_dir / Path(dataset_name)
(dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
@@ -74,7 +67,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
for transcript, audio_dur, wav_data in asr_data_source:
fname = tscript_uuid_fname(transcript)
fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
audio_file.write_bytes(wav_data)
rel_pnr_path = audio_file.relative_to(dataset_dir)
@@ -100,7 +93,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"data": [],
}
data_funcs = []
transcriber_gcp = gcp_transcribe_gen()
transcriber_pretrained = transcribe_gen(asr_port=8044)
with asr_manifest.open("w") as mf:
print(f"writing manifest to {asr_manifest}")
@@ -117,8 +109,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
rel_pnr_path,
):
pretrained_result = transcriber_pretrained(aud_seg.raw_data)
gcp_seg = aud_seg.set_frame_rate(16000)
gcp_result = transcriber_gcp(gcp_seg.raw_data)
pretrained_wer = word_error_rate([transcript], [pretrained_result])
wav_plot_path = (
dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
@@ -134,7 +124,6 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
"spoken": transcript,
"caller": caller_name,
"utterance_id": fname,
"gcp_asr": gcp_result,
"pretrained_asr": pretrained_result,
"pretrained_wer": pretrained_wer,
"plot_path": str(wav_plot_path),
@@ -185,7 +174,7 @@ def asr_manifest_reader(data_manifest_path: Path):
pnr_data = [json.loads(v) for v in pnr_jsonl]
for p in pnr_data:
p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
p["text"] = p["text"].strip()
p["chars"] = Path(p["audio_filepath"]).stem
yield p
@@ -199,32 +188,6 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
mf.write(manifest)
def asr_test_writer(out_file_path: Path, source):
def dd_str(dd, idx):
path = dd["audio_filepath"]
# dur = dd["duration"]
# return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
res_file = out_file_path.with_suffix(".result.json")
with out_file_path.open("w") as of:
print(f"opening {out_file_path} for writing test")
results = []
idx = 0
for ui_dd in source:
results.append(ui_dd)
out_str = dd_str(ui_dd, idx)
of.write(out_str)
idx += 1
of.write("DO_HANGUP\n")
ExtendedPath(res_file).write_json(results)
def batch(iterable, n=1):
ls = len(iterable)
return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
class ExtendedPath(type(Path())):
"""docstring for ExtendedPath."""
@@ -271,219 +234,6 @@ def plot_seg(wav_plot_path, audio_path):
fig.savefig(wav_plot_f, format="png", dpi=50)
def generate_dates():
days = [i for i in range(1, 32)]
months = [
"January",
"February",
"March",
"April",
"May",
"June",
"July",
"August",
"September",
"October",
"November",
"December",
]
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
def ordinal(n):
return "%d%s" % (
n,
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
)
def canon_vars(d, m):
return [
ordinal(d) + " " + m,
m + " " + ordinal(d),
ordinal(d) + " of " + m,
m + " the " + ordinal(d),
str(d) + " " + m,
m + " " + str(d),
]
return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
def get_call_logs(call_obj, s3, call_meta_dir):
meta_s3_uri = call_obj["DataURI"]
s3_event_url_p = urlsplit(meta_s3_uri)
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
if not saved_meta_path.exists():
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
s3.download_file(
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
)
call_metas = json.load(saved_meta_path.open())
return call_metas
def gcp_transcribe_gen():
from google.cloud import speech_v1
from google.cloud.speech_v1 import enums
# import io
client = speech_v1.SpeechClient()
# local_file_path = 'resources/brooklyn_bridge.raw'
# The language of the supplied audio
language_code = "en-US"
model = "phone_call"
# Sample rate in Hertz of the audio data sent
sample_rate_hertz = 16000
# Encoding of audio data sent. This sample sets this explicitly.
# This field is optional for FLAC and WAV audio formats.
encoding = enums.RecognitionConfig.AudioEncoding.LINEAR16
config = {
"language_code": language_code,
"sample_rate_hertz": sample_rate_hertz,
"encoding": encoding,
"model": model,
"enable_automatic_punctuation": True,
"max_alternatives": 10,
"enable_word_time_offsets": True, # used to detect start and end time of utterances
"speech_contexts": [
{
"phrases": [
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_DIGIT_SEQUENCE",
"$TIME",
"$YEAR",
]
},
{
"phrases": [
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
]
},
{
"phrases": [
"PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
]
},
{"phrases": ["my name is"]},
{"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
{
"phrases": [
"John Smith",
"Carina Hu",
"Travis Lim",
"Marvin Tan",
"Samuel Tan",
"Dawn Mathew",
"Dawn",
"Mathew",
]
},
{
"phrases": [
"Beijing",
"Tokyo",
"London",
"19 August",
"7 October",
"11 December",
"17 September",
"19th August",
"7th October",
"11th December",
"17th September",
"ABC123",
"KWXUNP",
"XLU5K1",
"WL2JV6",
"KBS651",
]
},
{
"phrases": [
"first flight",
"second flight",
"third flight",
"first option",
"second option",
"third option",
"first one",
"second one",
"third one",
]
},
],
"metadata": {
"industry_naics_code_of_audio": 481111,
"interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
},
}
def sample_recognize(content):
"""
Transcribe a short audio file using synchronous speech recognition
Args:
content Raw LINEAR16 audio bytes (16 kHz) to transcribe
"""
# with io.open(local_file_path, "rb") as f:
# content = f.read()
audio = {"content": content}
response = client.recognize(config, audio)
for result in response.results:
# First alternative is the most probable result
return "/".join([alt.transcript for alt in result.alternatives])
# print(u"Transcript: {}".format(alternative.transcript))
return ""
return sample_recognize
def parallel_apply(fn, iterable, workers=8):
with ThreadPoolExecutor(max_workers=workers) as exe:
print(f"parallelly applying {fn}")
return [
res
for res in tqdm(
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
)
]
def main():
for c in random_pnr_generator():
print(c)
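For orientation when reading asr_manifest_reader and asr_manifest_writer in this file: each dataset manifest is a JSON-lines file with one record per utterance. The record below is an assumed example; only the field names are taken from the code in this diff.
```
# Assumed example of a single manifest record (one JSON object per line).
import json

record = {
    "audio_filepath": "wav/8b1f2c3d_heathrow.wav",  # relative to the dataset dir
    "duration": 2.37,                               # seconds
    "text": "heathrow airport",
}
manifest_line = json.dumps(record) + "\n"
```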

View File

@@ -11,7 +11,6 @@ from ..utils import (
ExtendedPath,
asr_manifest_reader,
asr_manifest_writer,
tscript_uuid_fname,
get_mongo_conn,
plot_seg,
)
@@ -181,16 +180,14 @@ def task_ui(
@app.command()
def dump_corrections(
task_uid: str,
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("corrections.json"),
):
dump_path = dump_dir / Path(data_name) / dump_fname
col = get_mongo_conn(col="asr_validation")
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@@ -259,93 +256,93 @@ class ExtractionType(str, Enum):
date = "dates"
city = "cities"
name = "names"
all = "all"
@app.command()
def split_extract(
data_name: str = typer.Option("call_alphanum", show_default=True),
# dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
# dump_dir: Path = Path("./data/valiation_data"),
dump_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/valiation_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
extraction_type: ExtractionType = ExtractionType.all,
corrections_file: Path = Path("corrections.json"),
conv_data_path: Path = Path("./data/conv_data.json"),
extraction_type: ExtractionType = ExtractionType.date,
):
import shutil
data_manifest_path = dump_dir / Path(data_name) / manifest_file
conv_data = ExtendedPath(conv_data_path).read_json()
def get_conv_data(cdp):
from itertools import product
def extract_data_of_type(extraction_key):
extraction_vals = conv_data[extraction_key]
dest_data_name = data_name + "_" + extraction_key.lower()
conv_data = json.load(cdp.open())
days = [str(i) for i in range(1, 32)]
months = conv_data["months"]
day_months = {d + " " + m for d, m in product(days, months)}
return {
"cities": set(conv_data["cities"]),
"names": set(conv_data["names"]),
"dates": day_months,
}
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_path = dest_data_dir / dump_file
dest_data_name = data_name + "_" + extraction_type.value
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
conv_data = get_conv_data(conv_data_path)
extraction_vals = conv_data[extraction_type.value]
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
yield m
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = manifest_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_dir = dump_dir / Path(dest_data_name)
dest_ui_dir.mkdir(exist_ok=True, parents=True)
dest_ui_path = dest_ui_dir / dump_file
dest_correction_path = dest_ui_dir / corrections_file
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
yield m
ui_data_path = dump_dir / Path(data_name) / dump_file
orig_ui_data = ExtendedPath(ui_data_path).read_json()
ui_data = orig_ui_data["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
final_data = []
for i, d in enumerate(extracted_ui_data):
d['real_idx'] = i
final_data.append(d)
orig_ui_data['data'] = final_data
ExtendedPath(dest_ui_path).write_json(orig_ui_data)
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
if corrections_file:
dest_correction_path = dest_data_dir / corrections_file
corrections_path = dump_dir / Path(data_name) / corrections_file
corrections = json.load(corrections_path.open())
extracted_corrections = list(
filter(
lambda c: c["code"] in file_ui_map
and file_ui_map[c["code"]]["text"] in extraction_vals,
corrections,
)
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
ui_data_path = dump_dir / Path(data_name) / dump_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_data = json.load(ui_data_path.open())["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
corrections = json.load(corrections_path.open())
if extraction_type.value == 'all':
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
extract_data_of_type(extraction_type.value)
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
extracted_corrections = list(
filter(
lambda c: c["code"] in file_ui_map
and file_ui_map[c["code"]]["text"] in extraction_vals,
corrections,
)
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
@app.command()
def update_corrections(
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_dir: Path = Path("./data/valiation_data"),
manifest_dir: Path = Path("./data/asr_data"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
ui_dump_file: Path = Path("ui_dump.json"),
skip_incorrect: bool = typer.Option(True, show_default=True),
# data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
# corrections_path: Path = Path("./data/valiation_data/corrections.json"),
skip_incorrect: bool = True,
):
data_manifest_path = dump_dir / Path(data_name) / manifest_file
data_manifest_path = manifest_dir / Path(data_name) / manifest_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
def correct_manifest(ui_dump_path, corrections_path):
corrections = ExtendedPath(corrections_path).read_json()
ui_data = ExtendedPath(ui_dump_path).read_json()['data']
def correct_manifest(manifest_data_gen, corrections_path):
corrections = json.load(corrections_path.open())
correct_set = {
c["code"] for c in corrections if c["value"]["status"] == "Correct"
}
@@ -358,38 +355,36 @@ def update_corrections(
# for d in manifest_data_gen:
# if d["chars"] in incorrect_set:
# d["audio_path"].unlink()
# renamed_set = set()
for d in ui_data:
if d["utterance_id"] in correct_set:
renamed_set = set()
for d in manifest_data_gen:
if d["chars"] in correct_set:
yield {
"audio_filepath": d["audio_filepath"],
"duration": d["duration"],
"text": d["text"],
}
elif d["utterance_id"] in correction_map:
correct_text = correction_map[d["utterance_id"]]
elif d["chars"] in correction_map:
correct_text = correction_map[d["chars"]]
if skip_incorrect:
print(
f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
)
else:
orig_audio_path = Path(d["audio_path"])
new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
new_audio_path = orig_audio_path.with_name(new_name)
orig_audio_path.replace(new_audio_path)
renamed_set.add(correct_text)
new_name = str(Path(correct_text).with_suffix(".wav"))
d["audio_path"].replace(d["audio_path"].with_name(new_name))
new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
yield {
"audio_filepath": new_filepath,
"duration": d["duration"],
"text": correct_text,
"text": alnum_to_asr_tokens(correct_text),
}
else:
orig_audio_path = Path(d["audio_path"])
# don't delete if another correction points to an old file
# if d["text"] not in renamed_set:
orig_audio_path.unlink()
# else:
# print(f'skipping deletion of correction:{d["text"]}')
if d["chars"] not in renamed_set:
d["audio_path"].unlink()
else:
print(f'skipping deletion of correction:{d["chars"]}')
typer.echo(f"Using data manifest:{data_manifest_path}")
dataset_dir = data_manifest_path.parent
@@ -398,8 +393,8 @@ def update_corrections(
if not backup_dir.exists():
typer.echo(f"backing up to :{backup_dir}")
shutil.copytree(str(dataset_dir), str(backup_dir))
# manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
manifest_gen = asr_manifest_reader(data_manifest_path)
corrected_manifest = correct_manifest(manifest_gen, corrections_path)
new_data_manifest_path = data_manifest_path.with_name("manifest.new")
asr_manifest_writer(new_data_manifest_path, corrected_manifest)
new_data_manifest_path.replace(data_manifest_path)
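A note for reading dump_corrections and correct_manifest above: the correction documents in the asr_validation collection expose at least type, task_id, code (the utterance id, i.e. the audio file stem) and value.status; the field holding the corrected transcript is not visible in this diff. A hedged sketch:
```
# Assumed shape of one correction document. Only "type", "task_id", "code" and
# "value"/"status" are confirmed by the code in this diff; the corrected-text
# field is an assumption and is left commented out.
assumed_correction = {
    "type": "correction",
    "task_id": "call_alphanum-1a2b3c4d",  # ends with "-<task_uid>" per dump_corrections
    "code": "8b1f2c3d_heathrow",          # matches Path(audio_filepath).stem
    "value": {
        "status": "Correct",              # used to build correct_set in correct_manifest
        # "text": "heathrow airport",     # assumed name for the corrected transcript
    },
}
```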

View File

@@ -41,7 +41,7 @@ def parse_args():
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=100,
checkpoint_save_freq=200,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
@@ -266,7 +266,6 @@ def create_all_dags(args, neural_factory):
folder=neural_factory.checkpoint_dir,
load_from_folder=args.load_dir,
step_freq=args.checkpoint_save_freq,
checkpoints_to_keep=30,
)
callbacks = [train_callback, chpt_callback]

View File

@@ -39,7 +39,6 @@ extra_requirements = {
"streamlit==0.58.0",
"natural==0.2.0",
"stringcase==1.2.0",
"google-cloud-speech~=1.3.1",
]
# "train": [
# "torchaudio==0.5.0",
@@ -66,15 +65,12 @@ setup(
"jasper_trainer = jasper.training.cli:main",
"jasper_data_tts_generate = jasper.data.tts_generator:main",
"jasper_data_conv_generate = jasper.data.conv_generator:main",
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
"jasper_data_test_generate = jasper.data.test_generator:main",
"jasper_data_call_recycle = jasper.data.call_recycler:main",
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
"jasper_data_server = jasper.data.server:main",
"jasper_data_validation = jasper.data.validation.process:main",
"jasper_data_preprocess = jasper.data.process:main",
"jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
]
},
zip_safe=False,