From 515e9c1037c5cd942739bc8c0fe9f9b347ed860d Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Wed, 24 Jun 2020 22:50:45 +0530
Subject: [PATCH] Add an NLU conv generator, one-shot split extraction, and
 task-uid correction dumps

1. Split-extract all data types in one shot with the --extraction-type all
   flag.
2. Add notes on diffing split-extracted data against the original data.
3. Add an NLU conv generator that builds conv data from NLU utterances and
   entities.
4. Add task uid support for dumping corrections.
5. Abstract the date-generation function into a shared utility.
---
 Notes.md                          |  5 ++
 jasper/data/conv_generator.py     | 45 +------------
 jasper/data/nlu_generator.py      | 98 +++++++++++++++++++++++++++++++
 jasper/data/unique_nlu.py         | 92 -----------------------------
 jasper/data/utils.py              | 50 ++++++++++++++--
 jasper/data/validation/process.py | 79 ++++++++++++++-----------
 setup.py                          |  1 +
 7 files changed, 196 insertions(+), 174 deletions(-)
 create mode 100644 Notes.md
 create mode 100644 jasper/data/nlu_generator.py
 delete mode 100644 jasper/data/unique_nlu.py

diff --git a/Notes.md b/Notes.md
new file mode 100644
index 0000000..4195305
--- /dev/null
+++ b/Notes.md
@@ -0,0 +1,5 @@
+
+> Diff the split-extracted manifests against the original after splitting by type:
+```
+diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json | sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json | sort)
+```
diff --git a/jasper/data/conv_generator.py b/jasper/data/conv_generator.py
index ef7c9c8..e03a240 100644
--- a/jasper/data/conv_generator.py
+++ b/jasper/data/conv_generator.py
@@ -1,8 +1,6 @@
 import typer
 from pathlib import Path
-from random import randrange
-from itertools import product
-from math import floor
+from .utils import generate_dates
 
 app = typer.Typer()
 
@@ -16,46 +14,7 @@ def export_conv_json(
 
     conv_data = ExtendedPath(conv_src).read_json()
 
-    days = [i for i in range(1, 32)]
-    months = [
-        "January",
-        "February",
-        "March",
-        "April",
-        "May",
-        "June",
-        "July",
-        "August",
-        "September",
-        "October",
-        "November",
-        "December",
-    ]
-    # ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
-
-    def ordinal(n):
-        return "%d%s" % (
-            n,
-            "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
-        )
-
-    def canon_vars(d, m):
-        return [
-            ordinal(d) + " " + m,
-            m + " " + ordinal(d),
-            ordinal(d) + " of " + m,
-            m + " the " + ordinal(d),
-            str(d) + " " + m,
-            m + " " + str(d),
-        ]
-
-    day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
-
-    conv_data["dates"] = day_months
-
-    def dates_data_gen():
-        i = randrange(len(day_months))
-        return day_months[i]
+    conv_data["dates"] = generate_dates()
 
     ExtendedPath(conv_dest).write_json(conv_data)
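
`generate_dates()` now lives in `jasper/data/utils.py` (see its diff below) and returns the same list the deleted block built: all 31 x 12 day/month pairs (including impossible ones like 31 February) in six surface forms each. A sketch of what a caller gets back, derived from `canon_vars`:

```python
# sketch: the six variants generate_dates() produces for day=1, month="January":
# "1st January", "January 1st", "1st of January", "January the 1st", "1 January", "January 1"
from jasper.data.utils import generate_dates

dates = generate_dates()
assert "1st of January" in dates
assert len(dates) == 31 * 12 * 6  # 2232 strings in total
```
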
diff --git a/jasper/data/nlu_generator.py b/jasper/data/nlu_generator.py
new file mode 100644
index 0000000..293a497
--- /dev/null
+++ b/jasper/data/nlu_generator.py
@@ -0,0 +1,98 @@
+"""Generate conv data from NLU utterance templates and entities."""
+from pathlib import Path
+
+import typer
+import pandas as pd
+from ruamel.yaml import YAML
+from .utils import generate_dates
+
+app = typer.Typer()
+
+
+def unique_entity_list(entity_template_tags, entity_data):
+    unique_entity_set = {
+        t
+        for n in range(1, 5)
+        for t in entity_data[f"Answer.utterance-{n}"].tolist()
+        if any(et in t for et in entity_template_tags)
+    }
+    return list(unique_entity_set)
+
+
+def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
+    yaml = YAML()
+    nlu_data = yaml.load(nlu_data_file.read_text())
+    for cf in nlu_data["csv_files"]:
+        data = pd.read_csv(cf["fname"])
+        for et in cf["entities"]:
+            entity_name = et["name"]
+            entity_template_tags = et["tags"]
+            if "filter" in et:
+                entity_data = data[data[cf["filter_key"]] == et["filter"]]
+            else:
+                entity_data = data
+            yield entity_name, entity_template_tags, entity_data
+
+
+def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
+    yaml = YAML()
+    nlu_data = yaml.load(nlu_data_file.read_text())
+    sm = {s["name"]: s for s in nlu_data["samples_per_entity"]}
+    return sm
+
+
+@app.command()
+def compute_unique_nlu_stats(
+    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
+):
+    for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
+        nlu_data_file
+    ):
+        entity_count = len(unique_entity_list(entity_template_tags, entity_data))
+        print(f"{entity_name}\t{entity_count}")
+
+
+def replace_entity(tmpl, value, tags):
+    result = tmpl
+    for t in tags:
+        result = result.replace(t, value)
+    return result
+
+
+@app.command()
+def export_nlu_conv_json(
+    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
+    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
+    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
+):
+    from .utils import ExtendedPath
+    from random import sample
+
+    entity_samples = nlu_samples_reader(nlu_data_file)
+    conv_data = ExtendedPath(conv_src).read_json()
+    conv_data["Dates"] = generate_dates()
+    result_dict = {}
+    data_count = 0
+    for entity_name, entity_template_tags, entity_data in nlu_entity_reader(
+        nlu_data_file
+    ):
+        entity_variants = sample(conv_data[entity_name], entity_samples[entity_name]["test_size"])
+        unique_entities = unique_entity_list(entity_template_tags, entity_data)
+        # templates are re-sampled for every entity value in the loop below
+        result_dict[entity_name] = []
+        for val in entity_variants:
+            sample_entities = sample(unique_entities, entity_samples[entity_name]["samples"])
+            for tmpl in sample_entities:
+                result = replace_entity(tmpl, val, entity_template_tags)
+                result_dict[entity_name].append(result)
+                data_count += 1
+    print(f"Total of {data_count} variants generated")
+    ExtendedPath(conv_dest).write_json(result_dict)
+
+
+def main():
+    app()
+
+
+if __name__ == "__main__":
+    main()
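
The patch ships no `nlu_data.yaml`; judging from `nlu_entity_reader` and `nlu_samples_reader`, the file is expected to parse into a structure like the following. The CSV path and tags are taken from the deleted `unique_nlu.py`; the entity names and sample sizes are illustrative assumptions:

```python
# hypothetical nlu_data.yaml, shown as the dict ruamel.yaml would load it into
nlu_data = {
    "csv_files": [
        {
            "fname": "./customer_utterance_processing/customer_provide_answer.csv",
            "filter_key": "Input.Answer",  # column matched when an entity declares a "filter"
            "entities": [
                {"name": "PNR", "tags": ["ZZZZZZ"], "filter": "ZZZZZZ"},
                {"name": "Payment", "tags": ["KPay", "ZPay", "Credit Card"]},
            ],
        }
    ],
    "samples_per_entity": [
        # test_size: values drawn per entity; samples: templates drawn per value
        {"name": "PNR", "test_size": 10, "samples": 4},
        {"name": "Payment", "test_size": 10, "samples": 4},
    ],
}
```
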
diff --git a/jasper/data/unique_nlu.py b/jasper/data/unique_nlu.py
deleted file mode 100644
index f2446db..0000000
--- a/jasper/data/unique_nlu.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import pandas as pd
-
-
-def compute_pnr_name_city():
-    data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")
-
-    def unique_pnr_count():
-        pnr_data = data[data["Input.Answer"] == "ZZZZZZ"]
-        unique_pnr_set = {
-            t
-            for n in range(1, 5)
-            for t in pnr_data[f"Answer.utterance-{n}"].tolist()
-            if "ZZZZZZ" in t
-        }
-        return len(unique_pnr_set)
-
-    def unique_name_count():
-        pnr_data = data[data["Input.Answer"] == "John Doe"]
-        unique_pnr_set = {
-            t
-            for n in range(1, 5)
-            for t in pnr_data[f"Answer.utterance-{n}"].tolist()
-            if "John Doe" in t
-        }
-        return len(unique_pnr_set)
-
-    def unique_city_count():
-        pnr_data = data[data["Input.Answer"] == "Heathrow Airport"]
-        unique_pnr_set = {
-            t
-            for n in range(1, 5)
-            for t in pnr_data[f"Answer.utterance-{n}"].tolist()
-            if "Heathrow Airport" in t
-        }
-        return len(unique_pnr_set)
-
-    def unique_entity_count(entity_template_tags):
-        # entity_data = data[data['Input.Prompt'] == entity_template_tag]
-        entity_data = data
-        unique_entity_set = {
-            t
-            for n in range(1, 5)
-            for t in entity_data[f"Answer.utterance-{n}"].tolist()
-            if any(et in t for et in entity_template_tags)
-        }
-        return len(unique_entity_set)
-
-    print('PNR', unique_pnr_count())
-    print('Name', unique_name_count())
-    print('City', unique_city_count())
-    print('Payment', unique_entity_count(['KPay', 'ZPay', 'Credit Card']))
-
-
-def compute_date():
-    entity_template_tags = ['27 january', 'December 18']
-    data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
-    # data.sample(10)
-
-    def unique_entity_count(entity_template_tags):
-        # entity_data = data[data['Input.Prompt'] == entity_template_tag]
-        entity_data = data
-        unique_entity_set = {
-            t
-            for n in range(1, 5)
-            for t in entity_data[f"Answer.utterance-{n}"].tolist()
-            if any(et in t for et in entity_template_tags)
-        }
-        return len(unique_entity_set)
-
-    print('Date', unique_entity_count(entity_template_tags))
-
-
-def compute_option():
-    entity_template_tag = 'third'
-    data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")
-
-    def unique_entity_count():
-        entity_data = data[data['Input.Prompt'] == entity_template_tag]
-        unique_entity_set = {
-            t
-            for n in range(1, 5)
-            for t in entity_data[f"Answer.utterance-{n}"].tolist()
-            if entity_template_tag in t
-        }
-        return len(unique_entity_set)
-
-    print('Option', unique_entity_count())
-
-
-compute_pnr_name_city()
-compute_date()
-compute_option()
diff --git a/jasper/data/utils.py b/jasper/data/utils.py
index fa3d85d..1e90f00 100644
--- a/jasper/data/utils.py
+++ b/jasper/data/utils.py
@@ -1,13 +1,17 @@
-import numpy as np
-import wave
 import io
 import os
 import json
+import wave
 from pathlib import Path
+from itertools import product
+from functools import partial
+from math import floor
+from uuid import uuid4
+from concurrent.futures import ThreadPoolExecutor
 
+import numpy as np
 import pymongo
 from slugify import slugify
-from uuid import uuid4
 from num2words import num2words
 from jasper.client import transcribe_gen
 from nemo.collections.asr.metrics import word_error_rate
@@ -15,8 +19,6 @@ import matplotlib.pyplot as plt
 import librosa
 import librosa.display
 from tqdm import tqdm
-from functools import partial
-from concurrent.futures import ThreadPoolExecutor
 
 
 def manifest_str(path, dur, text):
@@ -238,6 +240,44 @@ def plot_seg(wav_plot_path, audio_path):
     fig.savefig(wav_plot_f, format="png", dpi=50)
 
 
+def generate_dates():
+
+    days = [i for i in range(1, 32)]
+    months = [
+        "January",
+        "February",
+        "March",
+        "April",
+        "May",
+        "June",
+        "July",
+        "August",
+        "September",
+        "October",
+        "November",
+        "December",
+    ]
+    # ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
+
+    def ordinal(n):
+        return "%d%s" % (
+            n,
+            "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
+        )
+
+    def canon_vars(d, m):
+        return [
+            ordinal(d) + " " + m,
+            m + " " + ordinal(d),
+            ordinal(d) + " of " + m,
+            m + " the " + ordinal(d),
+            str(d) + " " + m,
+            m + " " + str(d),
+        ]
+
+    return [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
+
+
 def main():
     for c in random_pnr_generator():
         print(c)
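
The `ordinal` helper that moved into `utils.py` packs all four suffixes into `"tsnrhtdd"`: read with stride 4, index 0 yields "th", 1 yields "st", 2 yields "nd", 3 yields "rd", and the two boolean factors force index 0 for the teens and for last digits above 3. Spot checks:

```python
from math import floor

def ordinal(n):
    # same expression as in generate_dates()
    return "%d%s" % (n, "tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4])

assert [ordinal(n) for n in (1, 2, 3, 4, 11, 12, 13, 21, 101)] == [
    "1st", "2nd", "3rd", "4th", "11th", "12th", "13th", "21st", "101st"
]
```
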
diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py
index 750e96b..aa21ba4 100644
--- a/jasper/data/validation/process.py
+++ b/jasper/data/validation/process.py
@@ -181,14 +181,16 @@ def task_ui(
 
 @app.command()
 def dump_corrections(
+    task_uid: str,
     data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
     dump_fname: Path = Path("corrections.json"),
 ):
     dump_path = dump_dir / Path(data_name) / dump_fname
     col = get_mongo_conn(col="asr_validation")
-
-    cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
+    # resolve the full task_id from the short uid that trails it
+    task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
+    cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
     corrections = [c for c in cursor_obj]
     ExtendedPath(dump_path).write_json(corrections)
 
@@ -257,6 +259,7 @@ class ExtractionType(str, Enum):
     date = "dates"
     city = "cities"
     name = "names"
+    all = "all"
 
 
 @app.command()
@@ -267,50 +270,58 @@ def split_extract(
     dump_dir: Path = Path("./data/asr_data"),
     dump_file: Path = Path("ui_dump.json"),
     manifest_file: Path = Path("manifest.json"),
-    corrections_file: Path = Path("corrections.json"),
+    corrections_file: str = typer.Option("corrections.json", show_default=True),
     conv_data_path: Path = Path("./data/conv_data.json"),
-    extraction_type: ExtractionType = ExtractionType.date,
+    extraction_type: ExtractionType = ExtractionType.all,
 ):
     import shutil
 
-    dest_data_name = data_name + "_" + extraction_type.value
     data_manifest_path = dump_dir / Path(data_name) / manifest_file
     conv_data = ExtendedPath(conv_data_path).read_json()
-    extraction_vals = conv_data[extraction_type.value]
-    manifest_gen = asr_manifest_reader(data_manifest_path)
-    dest_data_dir = dump_dir / Path(dest_data_name)
-    dest_data_dir.mkdir(exist_ok=True, parents=True)
-    (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
-    dest_manifest_path = dest_data_dir / manifest_file
-    dest_ui_path = dest_data_dir / dump_file
-    dest_correction_path = dest_data_dir / corrections_file
+    def extract_data_of_type(extraction_key):
+        extraction_vals = conv_data[extraction_key]
+        dest_data_name = data_name + "_" + extraction_key.lower()
 
-    def extract_manifest(mg):
-        for m in mg:
-            if m["text"] in extraction_vals:
-                shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
-                yield m
+        manifest_gen = asr_manifest_reader(data_manifest_path)
+        dest_data_dir = dump_dir / Path(dest_data_name)
+        dest_data_dir.mkdir(exist_ok=True, parents=True)
+        (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
+        dest_manifest_path = dest_data_dir / manifest_file
+        dest_ui_path = dest_data_dir / dump_file
 
-    asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
+        def extract_manifest(mg):
+            for m in mg:
+                if m["text"] in extraction_vals:
+                    shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
+                    yield m
 
-    ui_data_path = dump_dir / Path(data_name) / dump_file
-    corrections_path = dump_dir / Path(data_name) / corrections_file
-    ui_data = json.load(ui_data_path.open())["data"]
-    file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
-    corrections = json.load(corrections_path.open())
+        asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
 
-    extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
-    ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
+        ui_data_path = dump_dir / Path(data_name) / dump_file
+        ui_data = json.load(ui_data_path.open())["data"]
+        file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
+        extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
+        ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
 
-    extracted_corrections = list(
-        filter(
-            lambda c: c["code"] in file_ui_map
-            and file_ui_map[c["code"]]["text"] in extraction_vals,
-            corrections,
-        )
-    )
-    ExtendedPath(dest_correction_path).write_json(extracted_corrections)
+        if corrections_file:
+            dest_correction_path = dest_data_dir / corrections_file
+            corrections_path = dump_dir / Path(data_name) / corrections_file
+            corrections = json.load(corrections_path.open())
+            extracted_corrections = list(
+                filter(
+                    lambda c: c["code"] in file_ui_map
+                    and file_ui_map[c["code"]]["text"] in extraction_vals,
+                    corrections,
+                )
+            )
+            ExtendedPath(dest_correction_path).write_json(extracted_corrections)
+
+    if extraction_type.value == "all":
+        for ext_key in conv_data:
+            extract_data_of_type(ext_key)
+    else:
+        extract_data_of_type(extraction_type.value)
 
 
 @app.command()
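
`dump_corrections` now resolves a full Mongo `task_id` from the short uid passed on the command line and dumps only that task's corrections. A sketch of the matching rule (the id format below is an assumption implied by the `rsplit("-", 1)` lookup):

```python
# hypothetical task_id values as stored in the asr_validation collection
task_ids = ["call_alphanum-0f3a9b", "call_upwork_test_cnd-77c21d"]
task_uid = "77c21d"  # the uid the operator passes to dump_corrections

task_id = [t for t in task_ids if t.rsplit("-", 1)[1] == task_uid][0]
assert task_id == "call_upwork_test_cnd-77c21d"
```
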
diff --git a/setup.py b/setup.py
index e879af2..d1c85d4 100644
--- a/setup.py
+++ b/setup.py
@@ -65,6 +65,7 @@ setup(
             "jasper_trainer = jasper.training.cli:main",
             "jasper_data_tts_generate = jasper.data.tts_generator:main",
             "jasper_data_conv_generate = jasper.data.conv_generator:main",
+            "jasper_data_nlu_generate = jasper.data.nlu_generator:main",
             "jasper_data_call_recycle = jasper.data.call_recycler:main",
             "jasper_data_asr_recycle = jasper.data.asr_recycler:main",
             "jasper_data_rev_recycle = jasper.data.rev_recycler:main",
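
With the new default `--extraction-type all`, `split_extract` writes one `<data_name>_<key>` directory per key in `conv_data.json`; the `call_upwork_test_cnd_*` glob in Notes.md assumes exactly that layout. The console script registered above exposes the NLU generator; assuming Typer's default hyphenation of command and option names, an invocation would look like:

```
jasper_data_nlu_generate export-nlu-conv-json --conv-src ./conv_data.json --conv-dest ./data/conv_data.json --nlu-data-file ./nlu_data.yaml
```
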