Compare commits
3 Commits (000853b600 ... 069392d098)

| Author | SHA1 | Date |
|---|---|---|
|  | 069392d098 |  |
|  | 515e9c1037 |  |
|  | e76ccda5dd |  |
@@ -0,0 +1,5 @@
> Diff after splitting based on type

```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```
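The same check can be scripted; below is a rough Python equivalent of the one-liner above, a sketch only, with the dataset paths assumed to match the layout used in this branch.

```python
# Sanity-check sketch: the union of the per-type manifests should contain
# exactly the lines of the source manifest they were split from.
from pathlib import Path

def manifest_lines(paths):
    lines = set()
    for p in paths:
        lines |= set(Path(p).read_text().splitlines())
    return lines

split_lines = manifest_lines(Path("data/asr_data").glob("call_upwork_test_cnd_*/manifest.json"))
source_lines = manifest_lines(["data/asr_data/call_upwork_test_cnd/manifest.json"])
assert split_lines == source_lines, "split manifests drifted from the source manifest"
```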
@@ -147,7 +147,7 @@ def analyze(
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
-from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll
+from .utils import ui_dump_manifest_writer, strip_silence, get_mongo_coll, get_call_logs
from pydub import AudioSegment
from natural.date import compress

@@ -170,18 +170,6 @@ def analyze(

    call_logs = yaml.load(call_logs_file.read_text())

-    def get_call_meta(call_obj):
-        meta_s3_uri = call_obj["DataURI"]
-        s3_event_url_p = urlsplit(meta_s3_uri)
-        saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
-        if not saved_meta_path.exists():
-            print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
-            s3.download_file(
-                s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
-            )
-        call_metas = json.load(saved_meta_path.open())
-        return call_metas
-
    def gen_ev_fev_timedelta(fev):
        fev_p = Timestamp()
        fev_p.FromJsonString(fev["CreatedTS"])

@@ -283,7 +271,7 @@ def analyze(
        return spoken

    def process_call(call_obj):
-        call_meta = get_call_meta(call_obj)
+        call_meta = get_call_logs(call_obj, s3, call_meta_dir)
        call_events = call_meta["Events"]

        def is_writer_uri_event(ev):
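For context, here is a minimal sketch of how the shared helper introduced above is wired up; the boto3 client, cache directory, and module path are assumptions drawn from the surrounding diff, not an excerpt of the actual code.

```python
# Sketch only: call-log fetching via the shared utils helper (signature taken from this diff).
from pathlib import Path

import boto3
from jasper.data.utils import get_call_logs  # module path assumed

s3 = boto3.client("s3")
call_meta_dir = Path("./data/call_metas")
call_meta_dir.mkdir(parents=True, exist_ok=True)

def process_call(call_obj):
    # downloads the event log referenced by call_obj["DataURI"] into call_meta_dir (cached on disk)
    call_meta = get_call_logs(call_obj, s3, call_meta_dir)
    return call_meta["Events"]
```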
@@ -1,8 +1,6 @@
import typer
from pathlib import Path
-from random import randrange
-from itertools import product
-from math import floor
+from .utils import generate_dates

app = typer.Typer()

@@ -16,46 +14,7 @@ def export_conv_json(

    conv_data = ExtendedPath(conv_src).read_json()

-    days = [i for i in range(1, 32)]
-    months = [
-        "January",
-        "February",
-        "March",
-        "April",
-        "May",
-        "June",
-        "July",
-        "August",
-        "September",
-        "October",
-        "November",
-        "December",
-    ]
-    # ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
-
-    def ordinal(n):
-        return "%d%s" % (
-            n,
-            "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
-        )
-
-    def canon_vars(d, m):
-        return [
-            ordinal(d) + " " + m,
-            m + " " + ordinal(d),
-            ordinal(d) + " of " + m,
-            m + " the " + ordinal(d),
-            str(d) + " " + m,
-            m + " " + str(d),
-        ]
-
-    day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
-
-    conv_data["dates"] = day_months
-
-    def dates_data_gen():
-        i = randrange(len(day_months))
-        return day_months[i]
-
+    conv_data["dates"] = generate_dates()
    ExtendedPath(conv_dest).write_json(conv_data)
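The inline date generation removed here moves into `generate_dates()` in `.utils`. For reference, a standalone sketch of the ordinal-suffix slice trick that code relies on, checked against a few values:

```python
from math import floor

# "tsnrhtdd" holds the suffixes th/st/nd/rd interleaved; the slice picks the right pair.
def ordinal(n):
    return "%d%s" % (n, "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4])

assert [ordinal(n) for n in (1, 2, 3, 4, 11, 21)] == ["1st", "2nd", "3rd", "4th", "11th", "21st"]
```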
@ -0,0 +1,98 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import typer
|
||||||
|
import pandas as pd
|
||||||
|
from ruamel.yaml import YAML
|
||||||
|
from itertools import product
|
||||||
|
from .utils import generate_dates
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def unique_entity_list(entity_template_tags, entity_data):
|
||||||
|
unique_entity_set = {
|
||||||
|
t
|
||||||
|
for n in range(1, 5)
|
||||||
|
for t in entity_data[f"Answer.utterance-{n}"].tolist()
|
||||||
|
if any(et in t for et in entity_template_tags)
|
||||||
|
}
|
||||||
|
return list(unique_entity_set)
|
||||||
|
|
||||||
|
|
||||||
|
def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
|
||||||
|
yaml = YAML()
|
||||||
|
nlu_data = yaml.load(nlu_data_file.read_text())
|
||||||
|
for cf in nlu_data["csv_files"]:
|
||||||
|
data = pd.read_csv(cf["fname"])
|
||||||
|
for et in cf["entities"]:
|
||||||
|
entity_name = et["name"]
|
||||||
|
entity_template_tags = et["tags"]
|
||||||
|
if "filter" in et:
|
||||||
|
entity_data = data[data[cf["filter_key"]] == et["filter"]]
|
||||||
|
else:
|
||||||
|
entity_data = data
|
||||||
|
yield entity_name, entity_template_tags, entity_data
|
||||||
|
|
||||||
|
|
||||||
|
def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
|
||||||
|
yaml = YAML()
|
||||||
|
nlu_data = yaml.load(nlu_data_file.read_text())
|
||||||
|
sm = {s["name"]: s for s in nlu_data["samples_per_entity"]}
|
||||||
|
return sm
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def compute_unique_nlu_stats(
|
||||||
|
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
|
||||||
|
):
|
||||||
|
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
|
||||||
|
nlu_data_file
|
||||||
|
):
|
||||||
|
entity_count = len(unique_entity_list(entity_template_tags, entity_data))
|
||||||
|
print(f"{entity_name}\t{entity_count}")
|
||||||
|
|
||||||
|
|
||||||
|
def replace_entity(tmpl, value, tags):
|
||||||
|
result = tmpl
|
||||||
|
for t in tags:
|
||||||
|
result = result.replace(t, value)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def export_nlu_conv_json(
|
||||||
|
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
|
||||||
|
conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
|
||||||
|
nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
|
||||||
|
):
|
||||||
|
from .utils import ExtendedPath
|
||||||
|
from random import sample
|
||||||
|
|
||||||
|
entity_samples = nlu_samples_reader(nlu_data_file)
|
||||||
|
conv_data = ExtendedPath(conv_src).read_json()
|
||||||
|
conv_data["Dates"] = generate_dates()
|
||||||
|
result_dict = {}
|
||||||
|
data_count = 0
|
||||||
|
for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
|
||||||
|
nlu_data_file
|
||||||
|
):
|
||||||
|
entity_variants = sample(conv_data[entity_name], entity_samples[entity_name]["test_size"])
|
||||||
|
unique_entites = unique_entity_list(entity_template_tags, entity_data)
|
||||||
|
# sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
|
||||||
|
result_dict[entity_name] = []
|
||||||
|
for val in entity_variants:
|
||||||
|
sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
|
||||||
|
for tmpl in sample_entites:
|
||||||
|
result = replace_entity(tmpl, val, entity_template_tags)
|
||||||
|
result_dict[entity_name].append(result)
|
||||||
|
data_count += 1
|
||||||
|
print(f"Total of {data_count} variants generated")
|
||||||
|
ExtendedPath(conv_dest).write_json(result_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):


@app.command()
-def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
+def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
    reader_list = []
    abs_manifest_path = Path("abs_manifest.json")
    for dataset_path in src_dataset_paths:
@ -0,0 +1,180 @@
|
||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
|
||||||
|
# from .utils import generate_dates, asr_test_writer
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(reg_path, coll, s3, call_meta_dir, city_code, test_path):
|
||||||
|
from time import sleep
|
||||||
|
import subprocess
|
||||||
|
from .utils import ExtendedPath, get_call_logs
|
||||||
|
|
||||||
|
coll.delete_many({"CallID": test_path.name})
|
||||||
|
# test_path = dump_dir / data_name / test_file
|
||||||
|
# "../saas_reg/regression/run.sh -f data/asr_data/call_upwork_test_cnd_cities/asr_test.reg"
|
||||||
|
test_output = subprocess.run(
|
||||||
|
["/bin/bash", "-c", f"{str(reg_path)} --addr [::]:15400 -f {str(test_path)}"]
|
||||||
|
)
|
||||||
|
if test_output.returncode != 0:
|
||||||
|
print("Error running test {test_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_meta():
|
||||||
|
call_meta = coll.find_one({"CallID": test_path.name})
|
||||||
|
if call_meta:
|
||||||
|
return call_meta
|
||||||
|
else:
|
||||||
|
sleep(2)
|
||||||
|
return get_meta()
|
||||||
|
|
||||||
|
call_meta = get_meta()
|
||||||
|
call_logs = get_call_logs(call_meta, s3, call_meta_dir)
|
||||||
|
call_events = call_logs["Events"]
|
||||||
|
|
||||||
|
test_data_path = test_path.with_suffix(".result.json")
|
||||||
|
test_data = ExtendedPath(test_data_path).read_json()
|
||||||
|
|
||||||
|
def is_final_asr_event_or_spoken(ev):
|
||||||
|
pld = json.loads(ev["Payload"])
|
||||||
|
return (
|
||||||
|
pld["AsrResult"]["Results"][0]["IsFinal"]
|
||||||
|
if ev["Type"] == "ASR_RESULT"
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_test_event(ev):
|
||||||
|
return (
|
||||||
|
ev["Author"] == "NLU"
|
||||||
|
or (ev["Author"] == "ASR" and is_final_asr_event_or_spoken(ev))
|
||||||
|
) and (ev["Type"] != "DEBUG")
|
||||||
|
|
||||||
|
test_evs = list(filter(is_test_event, call_events))
|
||||||
|
if len(test_evs) == 2:
|
||||||
|
try:
|
||||||
|
asr_payload = test_evs[0]["Payload"]
|
||||||
|
asr_result = json.loads(asr_payload)["AsrResult"]["Results"][0]
|
||||||
|
alt_tscripts = [alt["Transcript"] for alt in asr_result["Alternatives"]]
|
||||||
|
gcp_result = "|".join(alt_tscripts)
|
||||||
|
entity_asr = asr_result["AsrDynamicResults"][0]["Candidate"]["Transcript"]
|
||||||
|
nlu_payload = test_evs[1]["Payload"]
|
||||||
|
nlu_result_payload = json.loads(nlu_payload)["NluResults"]
|
||||||
|
entity = test_data[0]["entity"]
|
||||||
|
text = test_data[0]["text"]
|
||||||
|
audio_filepath = test_data[0]["audio_filepath"]
|
||||||
|
pretrained_asr = test_data[0]["pretrained_asr"]
|
||||||
|
nlu_entity = list(json.loads(nlu_result_payload)["Entities"].values())[0]
|
||||||
|
asr_entity = city_code[entity] if entity in city_code else "UNKNOWN"
|
||||||
|
entities_match = asr_entity == nlu_entity
|
||||||
|
result = "Success" if entities_match else "Fail"
|
||||||
|
return {
|
||||||
|
"expected_entity": entity,
|
||||||
|
"text": text,
|
||||||
|
"audio_filepath": audio_filepath,
|
||||||
|
"pretrained_asr": pretrained_asr,
|
||||||
|
"entity_asr": entity_asr,
|
||||||
|
"google_asr": gcp_result,
|
||||||
|
"nlu_result": nlu_result_payload,
|
||||||
|
"asr_entity": asr_entity,
|
||||||
|
"nlu_entity": nlu_entity,
|
||||||
|
"result": result,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
return {
|
||||||
|
"expected_entity": test_data[0]["entity"],
|
||||||
|
"text": test_data[0]["text"],
|
||||||
|
"audio_filepath": test_data[0]["audio_filepath"],
|
||||||
|
"pretrained_asr": test_data[0]["pretrained_asr"],
|
||||||
|
"entity_asr": "",
|
||||||
|
"google_asr": "",
|
||||||
|
"nlu_result": "",
|
||||||
|
"asr_entity": "",
|
||||||
|
"nlu_entity": "",
|
||||||
|
"result": "Error",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"expected_entity": test_data[0]["entity"],
|
||||||
|
"text": test_data[0]["text"],
|
||||||
|
"audio_filepath": test_data[0]["audio_filepath"],
|
||||||
|
"pretrained_asr": test_data[0]["pretrained_asr"],
|
||||||
|
"entity_asr": "",
|
||||||
|
"google_asr": "",
|
||||||
|
"nlu_result": "",
|
||||||
|
"asr_entity": "",
|
||||||
|
"nlu_entity": "",
|
||||||
|
"result": "Empty",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def evaluate_slu(
|
||||||
|
# conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
|
||||||
|
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
|
||||||
|
# extraction_key: str = "Cities",
|
||||||
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
|
call_meta_dir: Path = Path("./data/call_metas"),
|
||||||
|
test_file_pref: str = "asr_test",
|
||||||
|
mongo_uri: str = typer.Option(
|
||||||
|
"mongodb://localhost:27017/test.calls", show_default=True
|
||||||
|
),
|
||||||
|
test_results: Path = Path("./data/results.csv"),
|
||||||
|
airport_codes: Path = Path("./airports_code.csv"),
|
||||||
|
reg_path: Path = Path("../saas_reg/regression/run.sh"),
|
||||||
|
test_id: str = "5ef481f27031edf6910e94e0",
|
||||||
|
):
|
||||||
|
# import json
|
||||||
|
from .utils import get_mongo_coll
|
||||||
|
import pandas as pd
|
||||||
|
import boto3
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
# import subprocess
|
||||||
|
# from time import sleep
|
||||||
|
import csv
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
s3 = boto3.client("s3")
|
||||||
|
df = pd.read_csv(airport_codes)[["iata", "city"]]
|
||||||
|
city_code = pd.Series(df["iata"].values, index=df["city"]).to_dict()
|
||||||
|
|
||||||
|
test_files = list((dump_dir / data_name).glob(test_file_pref + "*.reg"))
|
||||||
|
coll = get_mongo_coll(mongo_uri)
|
||||||
|
with test_results.open("w") as csvfile:
|
||||||
|
fieldnames = [
|
||||||
|
"expected_entity",
|
||||||
|
"text",
|
||||||
|
"audio_filepath",
|
||||||
|
"pretrained_asr",
|
||||||
|
"entity_asr",
|
||||||
|
"google_asr",
|
||||||
|
"nlu_result",
|
||||||
|
"asr_entity",
|
||||||
|
"nlu_entity",
|
||||||
|
"result",
|
||||||
|
]
|
||||||
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
with ThreadPoolExecutor(max_workers=8) as exe:
|
||||||
|
print("starting all loading tasks")
|
||||||
|
for test_result in tqdm(
|
||||||
|
exe.map(
|
||||||
|
partial(run_test, reg_path, coll, s3, call_meta_dir, city_code),
|
||||||
|
test_files,
|
||||||
|
),
|
||||||
|
position=0,
|
||||||
|
leave=True,
|
||||||
|
total=len(test_files),
|
||||||
|
):
|
||||||
|
writer.writerow(test_result)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,99 @@
|
||||||
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
from .utils import generate_dates, asr_test_writer
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def export_test_reg(
|
||||||
|
conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
|
||||||
|
data_name: str = typer.Option("call_upwork_test_cnd_cities", show_default=True),
|
||||||
|
extraction_key: str = "Cities",
|
||||||
|
dump_dir: Path = Path("./data/asr_data"),
|
||||||
|
dump_file: Path = Path("ui_dump.json"),
|
||||||
|
manifest_file: Path = Path("manifest.json"),
|
||||||
|
test_file: Path = Path("asr_test.reg"),
|
||||||
|
):
|
||||||
|
from .utils import (
|
||||||
|
ExtendedPath,
|
||||||
|
asr_manifest_reader,
|
||||||
|
gcp_transcribe_gen,
|
||||||
|
parallel_apply,
|
||||||
|
)
|
||||||
|
from ..client import transcribe_gen
|
||||||
|
from pydub import AudioSegment
|
||||||
|
from queue import PriorityQueue
|
||||||
|
|
||||||
|
jasper_map = {
|
||||||
|
"PNRs": 8045,
|
||||||
|
"Cities": 8046,
|
||||||
|
"Names": 8047,
|
||||||
|
"Dates": 8048,
|
||||||
|
}
|
||||||
|
# jasper_map = {"PNRs": 8050, "Cities": 8050, "Names": 8050, "Dates": 8050}
|
||||||
|
transcriber_gcp = gcp_transcribe_gen()
|
||||||
|
transcriber_trained = transcribe_gen(asr_port=jasper_map[extraction_key])
|
||||||
|
transcriber_all_trained = transcribe_gen(asr_port=8050)
|
||||||
|
transcriber_libri_all_trained = transcribe_gen(asr_port=8051)
|
||||||
|
|
||||||
|
def find_ent(dd, conv_data):
|
||||||
|
ents = PriorityQueue()
|
||||||
|
for ent in conv_data:
|
||||||
|
if ent in dd["text"]:
|
||||||
|
ents.put((-len(ent), ent))
|
||||||
|
return ents.get_nowait()[1]
|
||||||
|
|
||||||
|
def process_data(d):
|
||||||
|
orig_seg = AudioSegment.from_wav(d["audio_path"])
|
||||||
|
jas_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(24000)
|
||||||
|
gcp_seg = orig_seg.set_channels(1).set_sample_width(2).set_frame_rate(16000)
|
||||||
|
deepgram_file = Path("/home/shubham/voice_auto/pnrs/wav/") / Path(
|
||||||
|
d["audio_path"].stem + ".txt"
|
||||||
|
)
|
||||||
|
if deepgram_file.exists():
|
||||||
|
d["deepgram"] = "".join(
|
||||||
|
[s.replace("CHANNEL 0:", "") for s in deepgram_file.read_text().split("\n")]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
d["deepgram"] = 'Not Found'
|
||||||
|
d["audio_path"] = str(d["audio_path"])
|
||||||
|
d["gcp_transcript"] = transcriber_gcp(gcp_seg.raw_data)
|
||||||
|
d["jasper_trained"] = transcriber_trained(jas_seg.raw_data)
|
||||||
|
d["jasper_all"] = transcriber_all_trained(jas_seg.raw_data)
|
||||||
|
d["jasper_libri"] = transcriber_libri_all_trained(jas_seg.raw_data)
|
||||||
|
return d
|
||||||
|
|
||||||
|
conv_data = ExtendedPath(conv_src).read_json()
|
||||||
|
conv_data["Dates"] = generate_dates()
|
||||||
|
|
||||||
|
dump_data_path = dump_dir / Path(data_name) / dump_file
|
||||||
|
ui_dump_data = ExtendedPath(dump_data_path).read_json()["data"]
|
||||||
|
ui_dump_map = {i["utterance_id"]: i for i in ui_dump_data}
|
||||||
|
manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||||
|
test_points = list(asr_manifest_reader(manifest_path))
|
||||||
|
test_data_objs = [{**(ui_dump_map[t["audio_path"].stem]), **t} for t in test_points]
|
||||||
|
test_data = parallel_apply(process_data, test_data_objs)
|
||||||
|
# test_data = [process_data(t) for t in test_data_objs]
|
||||||
|
test_path = dump_dir / Path(data_name) / test_file
|
||||||
|
|
||||||
|
def dd_gen(dump_data):
|
||||||
|
for dd in dump_data:
|
||||||
|
ent = find_ent(dd, conv_data[extraction_key])
|
||||||
|
dd["entity"] = ent
|
||||||
|
if ent:
|
||||||
|
yield dd
|
||||||
|
|
||||||
|
asr_test_writer(test_path, dd_gen(test_data))
|
||||||
|
# for i, b in enumerate(batch(test_data, 1)):
|
||||||
|
# test_fname = Path(f"{test_file.stem}_{i}.reg")
|
||||||
|
# test_path = dump_dir / Path(data_name) / test_fname
|
||||||
|
# asr_test_writer(test_path, dd_gen(test_data))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,92 +0,0 @@
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
|
|
||||||
def compute_pnr_name_city():
|
|
||||||
data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")
|
|
||||||
|
|
||||||
def unique_pnr_count():
|
|
||||||
pnr_data = data[data["Input.Answer"] == "ZZZZZZ"]
|
|
||||||
unique_pnr_set = {
|
|
||||||
t
|
|
||||||
for n in range(1, 5)
|
|
||||||
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
|
|
||||||
if "ZZZZZZ" in t
|
|
||||||
}
|
|
||||||
return len(unique_pnr_set)
|
|
||||||
|
|
||||||
def unique_name_count():
|
|
||||||
pnr_data = data[data["Input.Answer"] == "John Doe"]
|
|
||||||
unique_pnr_set = {
|
|
||||||
t
|
|
||||||
for n in range(1, 5)
|
|
||||||
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
|
|
||||||
if "John Doe" in t
|
|
||||||
}
|
|
||||||
return len(unique_pnr_set)
|
|
||||||
|
|
||||||
def unique_city_count():
|
|
||||||
pnr_data = data[data["Input.Answer"] == "Heathrow Airport"]
|
|
||||||
unique_pnr_set = {
|
|
||||||
t
|
|
||||||
for n in range(1, 5)
|
|
||||||
for t in pnr_data[f"Answer.utterance-{n}"].tolist()
|
|
||||||
if "Heathrow Airport" in t
|
|
||||||
}
|
|
||||||
return len(unique_pnr_set)
|
|
||||||
|
|
||||||
def unique_entity_count(entity_template_tags):
|
|
||||||
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
|
|
||||||
entity_data = data
|
|
||||||
unique_entity_set = {
|
|
||||||
t
|
|
||||||
for n in range(1, 5)
|
|
||||||
for t in entity_data[f"Answer.utterance-{n}"].tolist()
|
|
||||||
if any(et in t for et in entity_template_tags)
|
|
||||||
}
|
|
||||||
return len(unique_entity_set)
|
|
||||||
|
|
||||||
print('PNR', unique_pnr_count())
|
|
||||||
print('Name', unique_name_count())
|
|
||||||
print('City', unique_city_count())
|
|
||||||
print('Payment', unique_entity_count(['KPay', 'ZPay', 'Credit Card']))
|
|
||||||
|
|
||||||
|
|
||||||
def compute_date():
|
|
||||||
entity_template_tags = ['27 january', 'December 18']
|
|
||||||
data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
|
|
||||||
# data.sample(10)
|
|
||||||
|
|
||||||
def unique_entity_count(entity_template_tags):
|
|
||||||
# entity_data = data[data['Input.Prompt'] == entity_template_tag]
|
|
||||||
entity_data = data
|
|
||||||
unique_entity_set = {
|
|
||||||
t
|
|
||||||
for n in range(1, 5)
|
|
||||||
for t in entity_data[f"Answer.utterance-{n}"].tolist()
|
|
||||||
if any(et in t for et in entity_template_tags)
|
|
||||||
}
|
|
||||||
return len(unique_entity_set)
|
|
||||||
|
|
||||||
print('Date', unique_entity_count(entity_template_tags))
|
|
||||||
|
|
||||||
|
|
||||||
def compute_option():
|
|
||||||
entity_template_tag = 'third'
|
|
||||||
data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")
|
|
||||||
|
|
||||||
def unique_entity_count():
|
|
||||||
entity_data = data[data['Input.Prompt'] == entity_template_tag]
|
|
||||||
unique_entity_set = {
|
|
||||||
t
|
|
||||||
for n in range(1, 5)
|
|
||||||
for t in entity_data[f"Answer.utterance-{n}"].tolist()
|
|
||||||
if entity_template_tag in t
|
|
||||||
}
|
|
||||||
return len(unique_entity_set)
|
|
||||||
|
|
||||||
print('Option', unique_entity_count())
|
|
||||||
|
|
||||||
|
|
||||||
compute_pnr_name_city()
|
|
||||||
compute_date()
|
|
||||||
compute_option()
|
|
||||||
|
|
@@ -1,13 +1,18 @@
-import numpy as np
-import wave
import io
import os
import json
+import wave
from pathlib import Path
+from itertools import product
+from functools import partial
+from math import floor
+from uuid import uuid4
+from urllib.parse import urlsplit
+from concurrent.futures import ThreadPoolExecutor

+import numpy as np
import pymongo
from slugify import slugify
-from uuid import uuid4
from num2words import num2words
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate

@@ -15,8 +20,6 @@ import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
-from functools import partial
-from concurrent.futures import ThreadPoolExecutor


def manifest_str(path, dur, text):
@@ -59,6 +62,10 @@ def alnum_to_asr_tokens(text):
    return ("".join(num_tokens)).lower()


+def tscript_uuid_fname(transcript):
+    return str(uuid4()) + "_" + slugify(transcript, max_length=8)
+
+
def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    dataset_dir = output_dir / Path(dataset_name)
    (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)

@@ -67,7 +74,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")
        for transcript, audio_dur, wav_data in asr_data_source:
-            fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
+            fname = tscript_uuid_fname(transcript)
            audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
            audio_file.write_bytes(wav_data)
            rel_pnr_path = audio_file.relative_to(dataset_dir)
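A small illustration of the new `tscript_uuid_fname` helper; the function body is copied from the hunk above so the snippet runs standalone, and the printed value is only an example.

```python
from uuid import uuid4
from slugify import slugify

def tscript_uuid_fname(transcript):
    return str(uuid4()) + "_" + slugify(transcript, max_length=8)

# A random UUID keeps filenames unique; the slug keeps them greppable, e.g.
# 'd2f6b6e0-..._november' for the transcript "November the 21st".
print(tscript_uuid_fname("November the 21st"))
```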
@@ -93,6 +100,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
        "data": [],
    }
    data_funcs = []
+    transcriber_gcp = gcp_transcribe_gen()
    transcriber_pretrained = transcribe_gen(asr_port=8044)
    with asr_manifest.open("w") as mf:
        print(f"writing manifest to {asr_manifest}")

@@ -109,6 +117,8 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
            rel_pnr_path,
        ):
            pretrained_result = transcriber_pretrained(aud_seg.raw_data)
+            gcp_seg = aud_seg.set_frame_rate(16000)
+            gcp_result = transcriber_gcp(gcp_seg.raw_data)
            pretrained_wer = word_error_rate([transcript], [pretrained_result])
            wav_plot_path = (
                dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")

@@ -124,6 +134,7 @@ def ui_dump_manifest_writer(output_dir, dataset_name, asr_data_source, verbose=F
                "spoken": transcript,
                "caller": caller_name,
                "utterance_id": fname,
+                "gcp_asr": gcp_result,
                "pretrained_asr": pretrained_result,
                "pretrained_wer": pretrained_wer,
                "plot_path": str(wav_plot_path),

@@ -174,7 +185,7 @@ def asr_manifest_reader(data_manifest_path: Path):
        pnr_data = [json.loads(v) for v in pnr_jsonl]
        for p in pnr_data:
            p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
-            p["chars"] = Path(p["audio_filepath"]).stem
+            p["text"] = p["text"].strip()
            yield p
@@ -188,6 +199,32 @@ def asr_manifest_writer(asr_manifest_path: Path, manifest_str_source):
            mf.write(manifest)


+def asr_test_writer(out_file_path: Path, source):
+    def dd_str(dd, idx):
+        path = dd["audio_filepath"]
+        # dur = dd["duration"]
+        # return f"SAY {idx}\nPAUSE 3\nPLAY {path}\nPAUSE 3\n\n"
+        return f"PAUSE 2\nPLAY {path}\nPAUSE 60\n\n"
+
+    res_file = out_file_path.with_suffix(".result.json")
+    with out_file_path.open("w") as of:
+        print(f"opening {out_file_path} for writing test")
+        results = []
+        idx = 0
+        for ui_dd in source:
+            results.append(ui_dd)
+            out_str = dd_str(ui_dd, idx)
+            of.write(out_str)
+            idx += 1
+        of.write("DO_HANGUP\n")
+    ExtendedPath(res_file).write_json(results)
+
+
+def batch(iterable, n=1):
+    ls = len(iterable)
+    return [iterable[ndx : min(ndx + n, ls)] for ndx in range(0, ls, n)]
+
+
class ExtendedPath(type(Path())):
    """docstring for ExtendedPath."""
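A hypothetical usage sketch for the new `asr_test_writer`: feeding it two dump entries produces a `.reg` call script plus a `.result.json` holding the original entries. The file names, entry contents, and module path here are assumptions for illustration.

```python
from pathlib import Path

from jasper.data.utils import asr_test_writer  # module path assumed

entries = [
    {"audio_filepath": "wav/one.wav", "text": "first utterance"},
    {"audio_filepath": "wav/two.wav", "text": "second utterance"},
]
asr_test_writer(Path("./asr_test.reg"), entries)
# asr_test.reg now holds a PAUSE/PLAY block per entry and a final DO_HANGUP;
# asr_test.result.json keeps the entries themselves for scoring the call later.
```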
@ -234,6 +271,219 @@ def plot_seg(wav_plot_path, audio_path):
|
||||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dates():
|
||||||
|
|
||||||
|
days = [i for i in range(1, 32)]
|
||||||
|
months = [
|
||||||
|
"January",
|
||||||
|
"February",
|
||||||
|
"March",
|
||||||
|
"April",
|
||||||
|
"May",
|
||||||
|
"June",
|
||||||
|
"July",
|
||||||
|
"August",
|
||||||
|
"September",
|
||||||
|
"October",
|
||||||
|
"November",
|
||||||
|
"December",
|
||||||
|
]
|
||||||
|
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
|
||||||
|
|
||||||
|
def ordinal(n):
|
||||||
|
return "%d%s" % (
|
||||||
|
n,
|
||||||
|
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
|
||||||
|
)
|
||||||
|
|
||||||
|
def canon_vars(d, m):
|
||||||
|
return [
|
||||||
|
ordinal(d) + " " + m,
|
||||||
|
m + " " + ordinal(d),
|
||||||
|
ordinal(d) + " of " + m,
|
||||||
|
m + " the " + ordinal(d),
|
||||||
|
str(d) + " " + m,
|
||||||
|
m + " " + str(d),
|
||||||
|
]
|
||||||
|
|
||||||
|
return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
|
||||||
|
|
||||||
|
|
||||||
|
def get_call_logs(call_obj, s3, call_meta_dir):
|
||||||
|
meta_s3_uri = call_obj["DataURI"]
|
||||||
|
s3_event_url_p = urlsplit(meta_s3_uri)
|
||||||
|
saved_meta_path = call_meta_dir / Path(Path(s3_event_url_p.path).name)
|
||||||
|
if not saved_meta_path.exists():
|
||||||
|
print(f"downloading : {saved_meta_path} from {meta_s3_uri}")
|
||||||
|
s3.download_file(
|
||||||
|
s3_event_url_p.netloc, s3_event_url_p.path[1:], str(saved_meta_path)
|
||||||
|
)
|
||||||
|
call_metas = json.load(saved_meta_path.open())
|
||||||
|
return call_metas
|
||||||
|
|
||||||
|
|
||||||
|
def gcp_transcribe_gen():
|
||||||
|
from google.cloud import speech_v1
|
||||||
|
from google.cloud.speech_v1 import enums
|
||||||
|
|
||||||
|
# import io
|
||||||
|
client = speech_v1.SpeechClient()
|
||||||
|
# local_file_path = 'resources/brooklyn_bridge.raw'
|
||||||
|
|
||||||
|
# The language of the supplied audio
|
||||||
|
language_code = "en-US"
|
||||||
|
model = "phone_call"
|
||||||
|
|
||||||
|
# Sample rate in Hertz of the audio data sent
|
||||||
|
sample_rate_hertz = 16000
|
||||||
|
|
||||||
|
# Encoding of audio data sent. This sample sets this explicitly.
|
||||||
|
# This field is optional for FLAC and WAV audio formats.
|
||||||
|
encoding = enums.RecognitionConfig.AudioEncoding.LINEAR16
|
||||||
|
config = {
|
||||||
|
"language_code": language_code,
|
||||||
|
"sample_rate_hertz": sample_rate_hertz,
|
||||||
|
"encoding": encoding,
|
||||||
|
"model": model,
|
||||||
|
"enable_automatic_punctuation": True,
|
||||||
|
"max_alternatives": 10,
|
||||||
|
"enable_word_time_offsets": True, # used to detect start and end time of utterances
|
||||||
|
"speech_contexts": [
|
||||||
|
{
|
||||||
|
"phrases": [
|
||||||
|
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||||
|
"$OOV_CLASS_DIGIT_SEQUENCE",
|
||||||
|
"$TIME",
|
||||||
|
"$YEAR",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"phrases": [
|
||||||
|
"A",
|
||||||
|
"B",
|
||||||
|
"C",
|
||||||
|
"D",
|
||||||
|
"E",
|
||||||
|
"F",
|
||||||
|
"G",
|
||||||
|
"H",
|
||||||
|
"I",
|
||||||
|
"J",
|
||||||
|
"K",
|
||||||
|
"L",
|
||||||
|
"M",
|
||||||
|
"N",
|
||||||
|
"O",
|
||||||
|
"P",
|
||||||
|
"Q",
|
||||||
|
"R",
|
||||||
|
"S",
|
||||||
|
"T",
|
||||||
|
"U",
|
||||||
|
"V",
|
||||||
|
"W",
|
||||||
|
"X",
|
||||||
|
"Y",
|
||||||
|
"Z",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"phrases": [
|
||||||
|
"PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||||
|
"my PNR is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||||
|
"my PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||||
|
"PNR number is $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||||
|
"It's $OOV_CLASS_ALPHANUMERIC_SEQUENCE",
|
||||||
|
"$OOV_CLASS_ALPHANUMERIC_SEQUENCE is my PNR",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{"phrases": ["my name is"]},
|
||||||
|
{"phrases": ["Number $ORDINAL", "Numeral $ORDINAL"]},
|
||||||
|
{
|
||||||
|
"phrases": [
|
||||||
|
"John Smith",
|
||||||
|
"Carina Hu",
|
||||||
|
"Travis Lim",
|
||||||
|
"Marvin Tan",
|
||||||
|
"Samuel Tan",
|
||||||
|
"Dawn Mathew",
|
||||||
|
"Dawn",
|
||||||
|
"Mathew",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"phrases": [
|
||||||
|
"Beijing",
|
||||||
|
"Tokyo",
|
||||||
|
"London",
|
||||||
|
"19 August",
|
||||||
|
"7 October",
|
||||||
|
"11 December",
|
||||||
|
"17 September",
|
||||||
|
"19th August",
|
||||||
|
"7th October",
|
||||||
|
"11th December",
|
||||||
|
"17th September",
|
||||||
|
"ABC123",
|
||||||
|
"KWXUNP",
|
||||||
|
"XLU5K1",
|
||||||
|
"WL2JV6",
|
||||||
|
"KBS651",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"phrases": [
|
||||||
|
"first flight",
|
||||||
|
"second flight",
|
||||||
|
"third flight",
|
||||||
|
"first option",
|
||||||
|
"second option",
|
||||||
|
"third option",
|
||||||
|
"first one",
|
||||||
|
"second one",
|
||||||
|
"third one",
|
||||||
|
]
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"industry_naics_code_of_audio": 481111,
|
||||||
|
"interaction_type": enums.RecognitionMetadata.InteractionType.PHONE_CALL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def sample_recognize(content):
|
||||||
|
"""
|
||||||
|
Transcribe a short audio file using synchronous speech recognition
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_file_path Path to local audio file, e.g. /path/audio.wav
|
||||||
|
"""
|
||||||
|
|
||||||
|
# with io.open(local_file_path, "rb") as f:
|
||||||
|
# content = f.read()
|
||||||
|
audio = {"content": content}
|
||||||
|
|
||||||
|
response = client.recognize(config, audio)
|
||||||
|
for result in response.results:
|
||||||
|
# First alternative is the most probable result
|
||||||
|
return "/".join([alt.transcript for alt in result.alternatives])
|
||||||
|
# print(u"Transcript: {}".format(alternative.transcript))
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return sample_recognize
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_apply(fn, iterable, workers=8):
|
||||||
|
with ThreadPoolExecutor(max_workers=workers) as exe:
|
||||||
|
print(f"parallelly applying {fn}")
|
||||||
|
return [
|
||||||
|
res
|
||||||
|
for res in tqdm(
|
||||||
|
exe.map(fn, iterable), position=0, leave=True, total=len(iterable)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
for c in random_pnr_generator():
|
for c in random_pnr_generator():
|
||||||
print(c)
|
print(c)
|
||||||
|
|
|
||||||
|
|
@@ -11,6 +11,7 @@ from ..utils import (
    ExtendedPath,
    asr_manifest_reader,
    asr_manifest_writer,
+    tscript_uuid_fname,
    get_mongo_conn,
    plot_seg,
)
@@ -180,14 +181,16 @@ def task_ui(


@app.command()
def dump_corrections(
+    task_uid: str,
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("corrections.json"),
):
    dump_path = dump_dir / Path(data_name) / dump_fname
    col = get_mongo_conn(col="asr_validation")
+    task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
-    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
+    corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
+    cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
    corrections = [c for c in cursor_obj]
    ExtendedPath(dump_path).write_json(corrections)
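The new `task_uid` argument selects one validation task by matching the trailing segment of the stored `task_id` values. A stand-in sketch of that lookup; the id format shown is hypothetical.

```python
# Stand-in for col.distinct("task_id"); the "<name>-<uid>" shape is an assumption.
task_ids = [
    "call_alphanum-5ef481f27031edf6910e94e0",
    "call_alphanum-5ef481f27031edf6910e94e1",
]
task_uid = "5ef481f27031edf6910e94e0"
task_id = [t for t in task_ids if t.rsplit("-", 1)[1] == task_uid][0]
assert task_id == "call_alphanum-5ef481f27031edf6910e94e0"
```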
@@ -256,49 +259,36 @@ class ExtractionType(str, Enum):
    date = "dates"
    city = "cities"
    name = "names"
+    all = "all"


@app.command()
def split_extract(
    data_name: str = typer.Option("call_alphanum", show_default=True),
    # dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
-    dump_dir: Path = Path("./data/valiation_data"),
+    # dump_dir: Path = Path("./data/valiation_data"),
+    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
-    manifest_dir: Path = Path("./data/asr_data"),
    manifest_file: Path = Path("manifest.json"),
-    corrections_file: Path = Path("corrections.json"),
-    conv_data_path: Path = Path("./data/conv_data.json"),
-    extraction_type: ExtractionType = ExtractionType.date,
+    corrections_file: str = typer.Option("corrections.json", show_default=True),
+    conv_data_path: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
+    extraction_type: ExtractionType = ExtractionType.all,
):
    import shutil

-    def get_conv_data(cdp):
-        from itertools import product
-
-        conv_data = json.load(cdp.open())
-        days = [str(i) for i in range(1, 32)]
-        months = conv_data["months"]
-        day_months = {d + " " + m for d, m in product(days, months)}
-        return {
-            "cities": set(conv_data["cities"]),
-            "names": set(conv_data["names"]),
-            "dates": day_months,
-        }
-
-    dest_data_name = data_name + "_" + extraction_type.value
-    data_manifest_path = manifest_dir / Path(data_name) / manifest_file
-    conv_data = get_conv_data(conv_data_path)
-    extraction_vals = conv_data[extraction_type.value]
+    data_manifest_path = dump_dir / Path(data_name) / manifest_file
+    conv_data = ExtendedPath(conv_data_path).read_json()
+
+    def extract_data_of_type(extraction_key):
+        extraction_vals = conv_data[extraction_key]
+        dest_data_name = data_name + "_" + extraction_key.lower()

    manifest_gen = asr_manifest_reader(data_manifest_path)
-    dest_data_dir = manifest_dir / Path(dest_data_name)
+    dest_data_dir = dump_dir / Path(dest_data_name)
    dest_data_dir.mkdir(exist_ok=True, parents=True)
    (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
    dest_manifest_path = dest_data_dir / manifest_file
-    dest_ui_dir = dump_dir / Path(dest_data_name)
-    dest_ui_dir.mkdir(exist_ok=True, parents=True)
-    dest_ui_path = dest_ui_dir / dump_file
-    dest_correction_path = dest_ui_dir / corrections_file
+    dest_ui_path = dest_data_dir / dump_file

    def extract_manifest(mg):
        for m in mg:

@@ -309,14 +299,21 @@ def split_extract(
    asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))

    ui_data_path = dump_dir / Path(data_name) / dump_file
-    corrections_path = dump_dir / Path(data_name) / corrections_file
-    ui_data = json.load(ui_data_path.open())["data"]
+    orig_ui_data = ExtendedPath(ui_data_path).read_json()
+    ui_data = orig_ui_data["data"]
    file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
-    corrections = json.load(corrections_path.open())

    extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
-    ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
+    final_data = []
+    for i, d in enumerate(extracted_ui_data):
+        d['real_idx'] = i
+        final_data.append(d)
+    orig_ui_data['data'] = final_data
+    ExtendedPath(dest_ui_path).write_json(orig_ui_data)
+
+    if corrections_file:
+        dest_correction_path = dest_data_dir / corrections_file
+        corrections_path = dump_dir / Path(data_name) / corrections_file
+        corrections = json.load(corrections_path.open())
    extracted_corrections = list(
        filter(
            lambda c: c["code"] in file_ui_map
@@ -326,23 +323,29 @@ def split_extract(
        )
    ExtendedPath(dest_correction_path).write_json(extracted_corrections)

+    if extraction_type.value == 'all':
+        for ext_key in conv_data.keys():
+            extract_data_of_type(ext_key)
+    else:
+        extract_data_of_type(extraction_type.value)
+

@app.command()
def update_corrections(
    data_name: str = typer.Option("call_alphanum", show_default=True),
-    dump_dir: Path = Path("./data/valiation_data"),
-    manifest_dir: Path = Path("./data/asr_data"),
+    dump_dir: Path = Path("./data/asr_data"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: Path = Path("corrections.json"),
-    # data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    # corrections_path: Path = Path("./data/valiation_data/corrections.json"),
-    skip_incorrect: bool = True,
+    ui_dump_file: Path = Path("ui_dump.json"),
+    skip_incorrect: bool = typer.Option(True, show_default=True),
):
-    data_manifest_path = manifest_dir / Path(data_name) / manifest_file
+    data_manifest_path = dump_dir / Path(data_name) / manifest_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
+    ui_dump_path = dump_dir / Path(data_name) / ui_dump_file

-    def correct_manifest(manifest_data_gen, corrections_path):
-        corrections = json.load(corrections_path.open())
+    def correct_manifest(ui_dump_path, corrections_path):
+        corrections = ExtendedPath(corrections_path).read_json()
+        ui_data = ExtendedPath(ui_dump_path).read_json()['data']
        correct_set = {
            c["code"] for c in corrections if c["value"]["status"] == "Correct"
        }
@@ -355,36 +358,38 @@ def update_corrections(
        # for d in manifest_data_gen:
        #     if d["chars"] in incorrect_set:
        #         d["audio_path"].unlink()
-        renamed_set = set()
-        for d in manifest_data_gen:
-            if d["chars"] in correct_set:
+        # renamed_set = set()
+        for d in ui_data:
+            if d["utterance_id"] in correct_set:
                yield {
                    "audio_filepath": d["audio_filepath"],
                    "duration": d["duration"],
                    "text": d["text"],
                }
-            elif d["chars"] in correction_map:
-                correct_text = correction_map[d["chars"]]
+            elif d["utterance_id"] in correction_map:
+                correct_text = correction_map[d["utterance_id"]]
                if skip_incorrect:
                    print(
                        f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
                    )
                else:
-                    renamed_set.add(correct_text)
-                    new_name = str(Path(correct_text).with_suffix(".wav"))
-                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
+                    orig_audio_path = Path(d["audio_path"])
+                    new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
+                    new_audio_path = orig_audio_path.with_name(new_name)
+                    orig_audio_path.replace(new_audio_path)
                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
                    yield {
                        "audio_filepath": new_filepath,
                        "duration": d["duration"],
-                        "text": alnum_to_asr_tokens(correct_text),
+                        "text": correct_text,
                    }
            else:
+                orig_audio_path = Path(d["audio_path"])
                # don't delete if another correction points to an old file
-                if d["chars"] not in renamed_set:
-                    d["audio_path"].unlink()
-                else:
-                    print(f'skipping deletion of correction:{d["chars"]}')
+                # if d["text"] not in renamed_set:
+                orig_audio_path.unlink()
+                # else:
+                #     print(f'skipping deletion of correction:{d["text"]}')

    typer.echo(f"Using data manifest:{data_manifest_path}")
    dataset_dir = data_manifest_path.parent
@@ -393,8 +398,8 @@ def update_corrections(
    if not backup_dir.exists():
        typer.echo(f"backing up to :{backup_dir}")
        shutil.copytree(str(dataset_dir), str(backup_dir))
-    manifest_gen = asr_manifest_reader(data_manifest_path)
-    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
+    # manifest_gen = asr_manifest_reader(data_manifest_path)
+    corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
    new_data_manifest_path = data_manifest_path.with_name("manifest.new")
    asr_manifest_writer(new_data_manifest_path, corrected_manifest)
    new_data_manifest_path.replace(data_manifest_path)
@@ -41,7 +41,7 @@ def parse_args():
        work_dir="./train/work",
        num_epochs=300,
        weight_decay=0.005,
-        checkpoint_save_freq=200,
+        checkpoint_save_freq=100,
        eval_freq=100,
        load_dir="./train/models/jasper/",
        warmup_steps=3,

@@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
        folder=neural_factory.checkpoint_dir,
        load_from_folder=args.load_dir,
        step_freq=args.checkpoint_save_freq,
+        checkpoints_to_keep=30,
    )

    callbacks = [train_callback, chpt_callback]
setup.py (4 changes)

@@ -39,6 +39,7 @@ extra_requirements = {
        "streamlit==0.58.0",
        "natural==0.2.0",
        "stringcase==1.2.0",
+        "google-cloud-speech~=1.3.1",
    ]
    # "train": [
    #     "torchaudio==0.5.0",

@@ -65,12 +66,15 @@ setup(
            "jasper_trainer = jasper.training.cli:main",
            "jasper_data_tts_generate = jasper.data.tts_generator:main",
            "jasper_data_conv_generate = jasper.data.conv_generator:main",
+            "jasper_data_nlu_generate = jasper.data.nlu_generator:main",
+            "jasper_data_test_generate = jasper.data.test_generator:main",
            "jasper_data_call_recycle = jasper.data.call_recycler:main",
            "jasper_data_asr_recycle = jasper.data.asr_recycler:main",
            "jasper_data_rev_recycle = jasper.data.rev_recycler:main",
            "jasper_data_server = jasper.data.server:main",
            "jasper_data_validation = jasper.data.validation.process:main",
            "jasper_data_preprocess = jasper.data.process:main",
+            "jasper_data_slu_evaluate = jasper.data.slu_evaluator:main",
        ]
    },
    zip_safe=False,