1. split-extract all data types in one shot with the --extraction-type all flag
2. add notes about diffing split-extracted and original data
3. add an NLU conv generator to generate conv data from NLU utterances and entities
4. add task UID support for dumping corrections
5. abstract the date-generation fn into a shared utility
Malar Kannan 2020-06-24 22:50:45 +05:30
parent e76ccda5dd
commit 515e9c1037
7 changed files with 196 additions and 174 deletions

Notes.md Normal file
View File

@@ -0,0 +1,5 @@
> Diff after splitting based on type (an empty diff means the split manifests together match the original exactly)
```
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
```

View File

@@ -1,8 +1,6 @@
import typer
from pathlib import Path
from random import randrange
from itertools import product
from math import floor
from .utils import generate_dates
app = typer.Typer()
@@ -16,46 +14,7 @@ def export_conv_json(
    conv_data = ExtendedPath(conv_src).read_json()
    days = [i for i in range(1, 32)]
    months = [
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ]
    # ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
    def ordinal(n):
        return "%d%s" % (
            n,
            "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
        )
    def canon_vars(d, m):
        return [
            ordinal(d) + " " + m,
            m + " " + ordinal(d),
            ordinal(d) + " of " + m,
            m + " the " + ordinal(d),
            str(d) + " " + m,
            m + " " + str(d),
        ]
    day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
    conv_data["dates"] = day_months
    def dates_data_gen():
        i = randrange(len(day_months))
        return day_months[i]
    conv_data["dates"] = generate_dates()
    ExtendedPath(conv_dest).write_json(conv_data)

View File

@@ -0,0 +1,98 @@
from pathlib import Path

import typer
import pandas as pd
from ruamel.yaml import YAML
from itertools import product

from .utils import generate_dates

app = typer.Typer()


def unique_entity_list(entity_template_tags, entity_data):
    unique_entity_set = {
        t
        for n in range(1, 5)
        for t in entity_data[f"Answer.utterance-{n}"].tolist()
        if any(et in t for et in entity_template_tags)
    }
    return list(unique_entity_set)


def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    yaml = YAML()
    nlu_data = yaml.load(nlu_data_file.read_text())
    for cf in nlu_data["csv_files"]:
        data = pd.read_csv(cf["fname"])
        for et in cf["entities"]:
            entity_name = et["name"]
            entity_template_tags = et["tags"]
            if "filter" in et:
                entity_data = data[data[cf["filter_key"]] == et["filter"]]
            else:
                entity_data = data
            yield entity_name, entity_template_tags, entity_data


def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    yaml = YAML()
    nlu_data = yaml.load(nlu_data_file.read_text())
    sm = {s["name"]: s for s in nlu_data["samples_per_entity"]}
    return sm


@app.command()
def compute_unique_nlu_stats(
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
        nlu_data_file
    ):
        entity_count = len(unique_entity_list(entity_template_tags, entity_data))
        print(f"{entity_name}\t{entity_count}")


def replace_entity(tmpl, value, tags):
    result = tmpl
    for t in tags:
        result = result.replace(t, value)
    return result


@app.command()
def export_nlu_conv_json(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    from .utils import ExtendedPath
    from random import sample

    entity_samples = nlu_samples_reader(nlu_data_file)
    conv_data = ExtendedPath(conv_src).read_json()
    conv_data["Dates"] = generate_dates()
    result_dict = {}
    data_count = 0
    for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
        nlu_data_file
    ):
        entity_variants = sample(conv_data[entity_name], entity_samples[entity_name]["test_size"])
        unique_entites = unique_entity_list(entity_template_tags, entity_data)
        # sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
        result_dict[entity_name] = []
        for val in entity_variants:
            sample_entites = sample(unique_entites, entity_samples[entity_name]["samples"])
            for tmpl in sample_entites:
                result = replace_entity(tmpl, val, entity_template_tags)
                result_dict[entity_name].append(result)
                data_count += 1
    print(f"Total of {data_count} variants generated")
    ExtendedPath(conv_dest).write_json(result_dict)


def main():
    app()


if __name__ == "__main__":
    main()
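
For orientation, here is a minimal sketch of the structure nlu_data.yaml is expected to parse to, inferred from nlu_entity_reader and nlu_samples_reader above. The key names (csv_files, fname, filter_key, entities, name, tags, filter, samples_per_entity, test_size, samples) come from the code; every concrete value below is hypothetical:

```python
# Hypothetical nlu_data.yaml contents, shown as the parsed Python structure.
nlu_data = {
    "csv_files": [
        {
            "fname": "./customer_utterance_processing/customer_provide_answer.csv",
            "filter_key": "Input.Answer",  # column to match when an entity sets "filter"
            "entities": [
                {"name": "pnrs", "tags": ["ZZZZZZ"], "filter": "ZZZZZZ"},
                {"name": "names", "tags": ["John Doe"], "filter": "John Doe"},
            ],
        }
    ],
    "samples_per_entity": [
        # test_size: values drawn from conv_data; samples: templates filled per value
        {"name": "pnrs", "test_size": 10, "samples": 5},
        {"name": "names", "test_size": 10, "samples": 5},
    ],
}
```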

View File

@@ -1,92 +0,0 @@
import pandas as pd


def compute_pnr_name_city():
    data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")

    def unique_pnr_count():
        pnr_data = data[data["Input.Answer"] == "ZZZZZZ"]
        unique_pnr_set = {
            t
            for n in range(1, 5)
            for t in pnr_data[f"Answer.utterance-{n}"].tolist()
            if "ZZZZZZ" in t
        }
        return len(unique_pnr_set)

    def unique_name_count():
        pnr_data = data[data["Input.Answer"] == "John Doe"]
        unique_pnr_set = {
            t
            for n in range(1, 5)
            for t in pnr_data[f"Answer.utterance-{n}"].tolist()
            if "John Doe" in t
        }
        return len(unique_pnr_set)

    def unique_city_count():
        pnr_data = data[data["Input.Answer"] == "Heathrow Airport"]
        unique_pnr_set = {
            t
            for n in range(1, 5)
            for t in pnr_data[f"Answer.utterance-{n}"].tolist()
            if "Heathrow Airport" in t
        }
        return len(unique_pnr_set)

    def unique_entity_count(entity_template_tags):
        # entity_data = data[data['Input.Prompt'] == entity_template_tag]
        entity_data = data
        unique_entity_set = {
            t
            for n in range(1, 5)
            for t in entity_data[f"Answer.utterance-{n}"].tolist()
            if any(et in t for et in entity_template_tags)
        }
        return len(unique_entity_set)

    print('PNR', unique_pnr_count())
    print('Name', unique_name_count())
    print('City', unique_city_count())
    print('Payment', unique_entity_count(['KPay', 'ZPay', 'Credit Card']))


def compute_date():
    entity_template_tags = ['27 january', 'December 18']
    data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
    # data.sample(10)

    def unique_entity_count(entity_template_tags):
        # entity_data = data[data['Input.Prompt'] == entity_template_tag]
        entity_data = data
        unique_entity_set = {
            t
            for n in range(1, 5)
            for t in entity_data[f"Answer.utterance-{n}"].tolist()
            if any(et in t for et in entity_template_tags)
        }
        return len(unique_entity_set)

    print('Date', unique_entity_count(entity_template_tags))


def compute_option():
    entity_template_tag = 'third'
    data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")

    def unique_entity_count():
        entity_data = data[data['Input.Prompt'] == entity_template_tag]
        unique_entity_set = {
            t
            for n in range(1, 5)
            for t in entity_data[f"Answer.utterance-{n}"].tolist()
            if entity_template_tag in t
        }
        return len(unique_entity_set)

    print('Option', unique_entity_count())


compute_pnr_name_city()
compute_date()
compute_option()

View File

@@ -1,13 +1,17 @@
import numpy as np
import wave
import io
import os
import json
import wave
from pathlib import Path
from itertools import product
from functools import partial
from math import floor
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pymongo
from slugify import slugify
from uuid import uuid4
from num2words import num2words
from jasper.client import transcribe_gen
from nemo.collections.asr.metrics import word_error_rate
@@ -15,8 +19,6 @@ import matplotlib.pyplot as plt
import librosa
import librosa.display
from tqdm import tqdm
from functools import partial
from concurrent.futures import ThreadPoolExecutor
def manifest_str(path, dur, text):
@@ -238,6 +240,44 @@ def plot_seg(wav_plot_path, audio_path):
    fig.savefig(wav_plot_f, format="png", dpi=50)


def generate_dates():
    days = [i for i in range(1, 32)]
    months = [
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ]

    # ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
    def ordinal(n):
        return "%d%s" % (
            n,
            "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
        )

    def canon_vars(d, m):
        return [
            ordinal(d) + " " + m,
            m + " " + ordinal(d),
            ordinal(d) + " of " + m,
            m + " the " + ordinal(d),
            str(d) + " " + m,
            m + " " + str(d),
        ]

    return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]


def main():
    for c in random_pnr_generator():
        print(c)
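
The "tsnrhtdd" slice inside ordinal is dense; as a quick sanity check (a sketch for this write-up, not part of the commit), the boolean product selects a start index into the string, and the [i::4] slice reads off a two-letter suffix (0 → "th", 1 → "st", 2 → "nd", 3 → "rd"):

```python
from math import floor

def ordinal(n):
    # Same expression as in generate_dates(): the two boolean factors zero the
    # index out for teens (11-13) and for last digits >= 4, which all take "th".
    return "%d%s" % (
        n,
        "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
    )

assert [ordinal(n) for n in (1, 2, 3, 4, 11, 12, 13, 21)] == [
    "1st", "2nd", "3rd", "4th", "11th", "12th", "13th", "21st",
]
```

With 31 days, 12 months, and 6 surface forms per pair, generate_dates returns 31 × 12 × 6 = 2232 strings, calendar-impossible pairs such as "31st of February" included, just like the inline version it replaces.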

View File

@@ -181,14 +181,16 @@ def task_ui(
@app.command()
def dump_corrections(
    task_uid: str,
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("corrections.json"),
):
    dump_path = dump_dir / Path(data_name) / dump_fname
    col = get_mongo_conn(col="asr_validation")
    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
    task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
    corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
    cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
    corrections = [c for c in cursor_obj]
    ExtendedPath(dump_path).write_json(corrections)
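
The new lookup matches the user-supplied task_uid against the suffix after the last hyphen of each stored task_id. A minimal sketch with made-up IDs (the real format only needs to end in -&lt;uid&gt;):

```python
# Hypothetical task_ids as col.distinct("task_id") might return them.
task_ids = ["call_alphanum-1a2b3c", "call_upwork_test_cnd-9f8e7d"]
task_uid = "9f8e7d"

task_id = [t for t in task_ids if t.rsplit("-", 1)[1] == task_uid][0]
assert task_id == "call_upwork_test_cnd-9f8e7d"
```

Note that the trailing [0] raises IndexError when no task matches; a guarded lookup would be friendlier, but this mirrors the commit as written.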
@@ -257,6 +259,7 @@ class ExtractionType(str, Enum):
    date = "dates"
    city = "cities"
    name = "names"
    all = "all"


@app.command()
@@ -267,16 +270,18 @@ def split_extract(
    dump_dir: Path = Path("./data/asr_data"),
    dump_file: Path = Path("ui_dump.json"),
    manifest_file: Path = Path("manifest.json"),
    corrections_file: Path = Path("corrections.json"),
    corrections_file: str = typer.Option("corrections.json", show_default=True),
    conv_data_path: Path = Path("./data/conv_data.json"),
    extraction_type: ExtractionType = ExtractionType.date,
    extraction_type: ExtractionType = ExtractionType.all,
):
    import shutil

    dest_data_name = data_name + "_" + extraction_type.value
    data_manifest_path = dump_dir / Path(data_name) / manifest_file
    conv_data = ExtendedPath(conv_data_path).read_json()
    extraction_vals = conv_data[extraction_type.value]

    def extract_data_of_type(extraction_key):
        extraction_vals = conv_data[extraction_key]
        dest_data_name = data_name + "_" + extraction_key.lower()
        manifest_gen = asr_manifest_reader(data_manifest_path)
        dest_data_dir = dump_dir / Path(dest_data_name)
@@ -284,7 +289,6 @@
        (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
        dest_manifest_path = dest_data_dir / manifest_file
        dest_ui_path = dest_data_dir / dump_file
    dest_correction_path = dest_data_dir / corrections_file

        def extract_manifest(mg):
            for m in mg:
@@ -295,14 +299,15 @@
        asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
        ui_data_path = dump_dir / Path(data_name) / dump_file
    corrections_path = dump_dir / Path(data_name) / corrections_file
        ui_data = json.load(ui_data_path.open())["data"]
        file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
    corrections = json.load(corrections_path.open())
        extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
        ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
        if corrections_file:
            dest_correction_path = dest_data_dir / corrections_file
            corrections_path = dump_dir / Path(data_name) / corrections_file
            corrections = json.load(corrections_path.open())
            extracted_corrections = list(
                filter(
                    lambda c: c["code"] in file_ui_map
@@ -312,6 +317,12 @@
                )
            )
            ExtendedPath(dest_correction_path).write_json(extracted_corrections)

    if extraction_type.value == 'all':
        for ext_key in conv_data.keys():
            extract_data_of_type(ext_key)
    else:
        extract_data_of_type(extraction_type.value)


@app.command()
def update_corrections(
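
With the new default --extraction-type all, split_extract derives one destination dataset per key in conv_data.json. A sketch of the resulting directory names, assuming hypothetical keys matching the ExtractionType values:

```python
# Keys mirror the ExtractionType values; list contents elided.
conv_data = {"dates": [...], "cities": [...], "names": [...]}
data_name = "call_upwork_test_cnd"

dest_dirs = [f"{data_name}_{key.lower()}" for key in conv_data]
# -> ['call_upwork_test_cnd_dates', 'call_upwork_test_cnd_cities',
#     'call_upwork_test_cnd_names']: the call_upwork_test_cnd_* glob that the
#     diff command in Notes.md compares against the original manifest.
```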

View File

@@ -65,6 +65,7 @@ setup(
            "jasper_trainer = jasper.training.cli:main",
            "jasper_data_tts_generate = jasper.data.tts_generator:main",
            "jasper_data_conv_generate = jasper.data.conv_generator:main",
            "jasper_data_nlu_generate = jasper.data.nlu_generator:main",
            "jasper_data_call_recycle = jasper.data.call_recycler:main",
            "jasper_data_asr_recycle = jasper.data.asr_recycler:main",
            "jasper_data_rev_recycle = jasper.data.rev_recycler:main",