mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-08 02:22:34 +00:00
1. split extract all data types in one shot with --extraction-type all flag
2. add notes about diffing split extracted and original data
3. add an NLU conv generator to generate conv data based on NLU utterances and entities
4. add task uid support for dumping corrections
5. abstract the generate-dates fn
This commit is contained in:
5
Notes.md
Normal file
5
Notes.md
Normal file
@@ -0,0 +1,5 @@
|
||||
|
||||
> Diff after splitting based on type
|
||||
```
|
||||
diff <(cat data/asr_data/call_upwork_test_cnd_*/manifest.json |sort) <(cat data/asr_data/call_upwork_test_cnd/manifest.json |sort)
|
||||
```
|
||||
@@ -1,8 +1,6 @@
|
||||
import typer
|
||||
from pathlib import Path
|
||||
from random import randrange
|
||||
from itertools import product
|
||||
from math import floor
|
||||
from .utils import generate_dates
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -16,46 +14,7 @@ def export_conv_json(
|
||||
|
||||
conv_data = ExtendedPath(conv_src).read_json()
|
||||
|
||||
days = [i for i in range(1, 32)]
|
||||
months = [
|
||||
"January",
|
||||
"February",
|
||||
"March",
|
||||
"April",
|
||||
"May",
|
||||
"June",
|
||||
"July",
|
||||
"August",
|
||||
"September",
|
||||
"October",
|
||||
"November",
|
||||
"December",
|
||||
]
|
||||
# ordinal from https://stackoverflow.com/questions/9647202/ordinal-numbers-replacement
|
||||
|
||||
def ordinal(n):
|
||||
return "%d%s" % (
|
||||
n,
|
||||
"tsnrhtdd"[(floor(n / 10) % 10 != 1) * (n % 10 < 4) * n % 10 :: 4],
|
||||
)
|
||||
|
||||
def canon_vars(d, m):
|
||||
return [
|
||||
ordinal(d) + " " + m,
|
||||
m + " " + ordinal(d),
|
||||
ordinal(d) + " of " + m,
|
||||
m + " the " + ordinal(d),
|
||||
str(d) + " " + m,
|
||||
m + " " + str(d),
|
||||
]
|
||||
|
||||
day_months = [dm for d, m in product(days, months) for dm in canon_vars(d, m)]
|
||||
|
||||
conv_data["dates"] = day_months
|
||||
|
||||
def dates_data_gen():
|
||||
i = randrange(len(day_months))
|
||||
return day_months[i]
|
||||
conv_data["dates"] = generate_dates()
|
||||
|
||||
ExtendedPath(conv_dest).write_json(conv_data)
|
||||
|
||||
|
||||
98
jasper/data/nlu_generator.py
Normal file
98
jasper/data/nlu_generator.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
import pandas as pd
|
||||
from ruamel.yaml import YAML
|
||||
from itertools import product
|
||||
from .utils import generate_dates
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def unique_entity_list(entity_template_tags, entity_data):
    """Collect the distinct annotator utterances that mention any template tag.

    Scans the four ``Answer.utterance-1`` .. ``Answer.utterance-4`` columns of
    *entity_data* and keeps every utterance containing at least one of the
    substrings in *entity_template_tags*.  Duplicates are removed; element
    order is unspecified (set-backed), as before.
    """
    matched = set()
    for column_index in range(1, 5):
        for utterance in entity_data[f"Answer.utterance-{column_index}"].tolist():
            if any(tag in utterance for tag in entity_template_tags):
                matched.add(utterance)
    return list(matched)
|
||||
|
||||
|
||||
def nlu_entity_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    """Yield ``(entity_name, template_tags, frame)`` triples from the NLU config.

    The YAML config lists CSV files together with the entities annotated in
    each.  Every CSV is loaded once; when an entity declares a ``filter``,
    the frame is restricted to rows whose ``filter_key`` column equals that
    filter value, otherwise the whole frame is yielded.
    """
    config = YAML().load(nlu_data_file.read_text())
    for csv_spec in config["csv_files"]:
        frame = pd.read_csv(csv_spec["fname"])
        for entity in csv_spec["entities"]:
            if "filter" in entity:
                subset = frame[frame[csv_spec["filter_key"]] == entity["filter"]]
            else:
                subset = frame
            yield entity["name"], entity["tags"], subset
|
||||
|
||||
|
||||
def nlu_samples_reader(nlu_data_file: Path = Path("./nlu_data.yaml")):
    """Return the per-entity sampling configuration keyed by entity name."""
    nlu_config = YAML().load(nlu_data_file.read_text())
    samples_by_name = {}
    for spec in nlu_config["samples_per_entity"]:
        samples_by_name[spec["name"]] = spec
    return samples_by_name
|
||||
|
||||
|
||||
@app.command()
def compute_unique_nlu_stats(
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    """Print one ``name<TAB>count`` line per entity, counting its unique
    annotator utterances."""
    for name, tags, frame in nlu_entity_reader(nlu_data_file):
        count = len(unique_entity_list(tags, frame))
        print(f"{name}\t{count}")
|
||||
|
||||
|
||||
def replace_entity(tmpl, value, tags):
    """Return *tmpl* with every occurrence of each tag replaced by *value*.

    Tags are applied in order, so a substituted value can itself be rewritten
    by a later tag if the tag strings overlap.
    """
    rewritten = tmpl
    for tag in tags:
        if tag in rewritten:
            rewritten = rewritten.replace(tag, value)
    return rewritten
|
||||
|
||||
|
||||
@app.command()
def export_nlu_conv_json(
    conv_src: Path = typer.Option(Path("./conv_data.json"), show_default=True),
    conv_dest: Path = typer.Option(Path("./data/conv_data.json"), show_default=True),
    nlu_data_file: Path = typer.Option(Path("./nlu_data.yaml"), show_default=True),
):
    """Generate conversational variants by filling NLU utterance templates.

    For every entity in the NLU config, draws ``test_size`` values from the
    conversation data and, per value, ``samples`` annotator utterances in
    which the template tags are replaced by that value.  The grouped results
    are written as JSON to *conv_dest*.
    """
    from .utils import ExtendedPath
    from random import sample

    sampling_cfg = nlu_samples_reader(nlu_data_file)
    conv_data = ExtendedPath(conv_src).read_json()
    # Dates are synthesized rather than taken from the source conv data.
    conv_data["Dates"] = generate_dates()

    variants_by_entity = {}
    data_count = 0
    for name, tags, frame in nlu_entity_reader(nlu_data_file):
        values = sample(conv_data[name], sampling_cfg[name]["test_size"])
        templates = unique_entity_list(tags, frame)
        generated = []
        for value in values:
            # Re-sample templates for each value so every value gets its own
            # random subset of utterances.
            for template in sample(templates, sampling_cfg[name]["samples"]):
                generated.append(replace_entity(template, value, tags))
                data_count += 1
        variants_by_entity[name] = generated
    print(f"Total of {data_count} variants generated")
    ExtendedPath(conv_dest).write_json(variants_by_entity)
|
||||
|
||||
|
||||
def main():
    """Console-script entry point: dispatch to the Typer application."""
    app()


if __name__ == "__main__":
    main()
|
||||
@@ -1,92 +0,0 @@
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def compute_pnr_name_city():
    """Print counts of unique annotator utterances for the PNR, name, city
    and payment entities found in the "provide answer" CSV.

    The three single-tag entities first restrict the rows to those whose
    ``Input.Answer`` column equals the placeholder value; the payment count
    scans every row for any of its tag spellings.
    """
    data = pd.read_csv("./customer_utterance_processing/customer_provide_answer.csv")

    def unique_count(tags, answer_filter=None):
        # One parameterized helper replaces the four near-identical closures
        # of the original implementation (which also reused the misleading
        # name `unique_pnr_set` for the name/city counts).
        rows = data if answer_filter is None else data[data["Input.Answer"] == answer_filter]
        unique_utterances = {
            t
            for n in range(1, 5)
            for t in rows[f"Answer.utterance-{n}"].tolist()
            if any(tag in t for tag in tags)
        }
        return len(unique_utterances)

    print('PNR', unique_count(["ZZZZZZ"], answer_filter="ZZZZZZ"))
    print('Name', unique_count(["John Doe"], answer_filter="John Doe"))
    print('City', unique_count(["Heathrow Airport"], answer_filter="Heathrow Airport"))
    print('Payment', unique_count(['KPay', 'ZPay', 'Credit Card']))
|
||||
|
||||
|
||||
def compute_date():
    """Print the number of unique departure-date utterances.

    An utterance counts if it contains either of the sample date spellings
    used as tags during annotation.
    """
    entity_template_tags = ['27 january', 'December 18']
    data = pd.read_csv("./customer_utterance_processing/customer_provide_departure.csv")
    # The original wrapped this in a one-shot inner closure and carried
    # commented-out filter code; the computation is inlined here unchanged.
    unique_utterances = {
        t
        for n in range(1, 5)
        for t in data[f"Answer.utterance-{n}"].tolist()
        if any(tag in t for tag in entity_template_tags)
    }
    print('Date', len(unique_utterances))
|
||||
|
||||
|
||||
def compute_option():
    """Print the number of unique flight-selection utterances that mention
    the 'third' option tag, restricted to rows prompted with that tag."""
    entity_template_tag = 'third'
    data = pd.read_csv("./customer_utterance_processing/customer_provide_flight_selection.csv")
    # Inlined from the original one-shot inner closure; logic unchanged.
    entity_data = data[data['Input.Prompt'] == entity_template_tag]
    unique_utterances = {
        t
        for n in range(1, 5)
        for t in entity_data[f"Answer.utterance-{n}"].tolist()
        if entity_template_tag in t
    }
    print('Option', len(unique_utterances))
|
||||
|
||||
|
||||
# Script body: run every statistics computation at import/execution time.
compute_pnr_name_city()
compute_date()
compute_option()
|
||||
@@ -1,13 +1,17 @@
|
||||
import numpy as np
|
||||
import wave
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from itertools import product
|
||||
from functools import partial
|
||||
from math import floor
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
import pymongo
|
||||
from slugify import slugify
|
||||
from uuid import uuid4
|
||||
from num2words import num2words
|
||||
from jasper.client import transcribe_gen
|
||||
from nemo.collections.asr.metrics import word_error_rate
|
||||
@@ -15,8 +19,6 @@ import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
import librosa.display
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
def manifest_str(path, dur, text):
|
||||
@@ -238,6 +240,44 @@ def plot_seg(wav_plot_path, audio_path):
|
||||
fig.savefig(wav_plot_f, format="png", dpi=50)
|
||||
|
||||
|
||||
def generate_dates():
    """Return every spoken-form variant of each (day, month) pair.

    Produces 31 days x 12 months x 6 phrasings (ordinal/cardinal, before and
    after the month, plus "of"/"the" forms) in day-major order.  Impossible
    dates such as "31st February" are included by design, since this is a
    plain Cartesian product of days and month names.
    """

    month_names = [
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ]

    def ordinal(n):
        # 11th-13th always take "th"; otherwise units 1/2/3 take st/nd/rd.
        if (n // 10) % 10 != 1 and n % 10 in (1, 2, 3):
            suffix = {1: "st", 2: "nd", 3: "rd"}[n % 10]
        else:
            suffix = "th"
        return "%d%s" % (n, suffix)

    def phrasings(day, month):
        return [
            ordinal(day) + " " + month,
            month + " " + ordinal(day),
            ordinal(day) + " of " + month,
            month + " the " + ordinal(day),
            str(day) + " " + month,
            month + " " + str(day),
        ]

    return [
        phrase
        for day, month in product(range(1, 32), month_names)
        for phrase in phrasings(day, month)
    ]
|
||||
|
||||
|
||||
def main():
    """Print every code produced by the random PNR generator."""
    for code in random_pnr_generator():
        print(code)
|
||||
|
||||
@@ -181,14 +181,16 @@ def task_ui(
|
||||
|
||||
@app.command()
def dump_corrections(
    task_uid: str,
    data_name: str = typer.Option("call_alphanum", show_default=True),
    dump_dir: Path = Path("./data/asr_data"),
    dump_fname: Path = Path("corrections.json"),
):
    """Dump the corrections recorded for a single validation task to JSON.

    *task_uid* is matched against the trailing segment of each stored
    ``task_id`` (format ``<name>-<uid>``); the first match is used.  Raises
    IndexError when no task id carries that suffix (as before).
    """
    dump_path = dump_dir / Path(data_name) / dump_fname
    col = get_mongo_conn(col="asr_validation")
    matching_ids = [t for t in col.distinct("task_id") if t.rsplit("-", 1)[1] == task_uid]
    task_id = matching_ids[0]
    cursor = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
    ExtendedPath(dump_path).write_json([c for c in cursor])
|
||||
|
||||
@@ -257,6 +259,7 @@ class ExtractionType(str, Enum):
|
||||
date = "dates"
|
||||
city = "cities"
|
||||
name = "names"
|
||||
all = "all"
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -267,50 +270,58 @@ def split_extract(
|
||||
dump_dir: Path = Path("./data/asr_data"),
|
||||
dump_file: Path = Path("ui_dump.json"),
|
||||
manifest_file: Path = Path("manifest.json"),
|
||||
corrections_file: Path = Path("corrections.json"),
|
||||
corrections_file: str = typer.Option("corrections.json", show_default=True),
|
||||
conv_data_path: Path = Path("./data/conv_data.json"),
|
||||
extraction_type: ExtractionType = ExtractionType.date,
|
||||
extraction_type: ExtractionType = ExtractionType.all,
|
||||
):
|
||||
import shutil
|
||||
|
||||
dest_data_name = data_name + "_" + extraction_type.value
|
||||
data_manifest_path = dump_dir / Path(data_name) / manifest_file
|
||||
conv_data = ExtendedPath(conv_data_path).read_json()
|
||||
extraction_vals = conv_data[extraction_type.value]
|
||||
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
dest_data_dir = dump_dir / Path(dest_data_name)
|
||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_data_dir / manifest_file
|
||||
dest_ui_path = dest_data_dir / dump_file
|
||||
dest_correction_path = dest_data_dir / corrections_file
|
||||
def extract_data_of_type(extraction_key):
|
||||
extraction_vals = conv_data[extraction_key]
|
||||
dest_data_name = data_name + "_" + extraction_key.lower()
|
||||
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
|
||||
yield m
|
||||
manifest_gen = asr_manifest_reader(data_manifest_path)
|
||||
dest_data_dir = dump_dir / Path(dest_data_name)
|
||||
dest_data_dir.mkdir(exist_ok=True, parents=True)
|
||||
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
|
||||
dest_manifest_path = dest_data_dir / manifest_file
|
||||
dest_ui_path = dest_data_dir / dump_file
|
||||
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
def extract_manifest(mg):
|
||||
for m in mg:
|
||||
if m["text"] in extraction_vals:
|
||||
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
|
||||
yield m
|
||||
|
||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
ui_data = json.load(ui_data_path.open())["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
corrections = json.load(corrections_path.open())
|
||||
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
|
||||
|
||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
||||
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
|
||||
ui_data_path = dump_dir / Path(data_name) / dump_file
|
||||
ui_data = json.load(ui_data_path.open())["data"]
|
||||
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
|
||||
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
|
||||
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
|
||||
|
||||
extracted_corrections = list(
|
||||
filter(
|
||||
lambda c: c["code"] in file_ui_map
|
||||
and file_ui_map[c["code"]]["text"] in extraction_vals,
|
||||
corrections,
|
||||
)
|
||||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
if corrections_file:
|
||||
dest_correction_path = dest_data_dir / corrections_file
|
||||
corrections_path = dump_dir / Path(data_name) / corrections_file
|
||||
corrections = json.load(corrections_path.open())
|
||||
extracted_corrections = list(
|
||||
filter(
|
||||
lambda c: c["code"] in file_ui_map
|
||||
and file_ui_map[c["code"]]["text"] in extraction_vals,
|
||||
corrections,
|
||||
)
|
||||
)
|
||||
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
|
||||
|
||||
if extraction_type.value == 'all':
|
||||
for ext_key in conv_data.keys():
|
||||
extract_data_of_type(ext_key)
|
||||
else:
|
||||
extract_data_of_type(extraction_type.value)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
1
setup.py
1
setup.py
@@ -65,6 +65,7 @@ setup(
|
||||
"jasper_trainer = jasper.training.cli:main",
|
||||
"jasper_data_tts_generate = jasper.data.tts_generator:main",
|
||||
"jasper_data_conv_generate = jasper.data.conv_generator:main",
|
||||
"jasper_data_nlu_generate = jasper.data.nlu_generator:main",
|
||||
"jasper_data_call_recycle = jasper.data.call_recycler:main",
|
||||
"jasper_data_asr_recycle = jasper.data.asr_recycler:main",
|
||||
"jasper_data_rev_recycle = jasper.data.rev_recycler:main",
|
||||
|
||||
Reference in New Issue
Block a user