1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-09 19:02:35 +00:00

1. split extract all data types in one shot with --extraction-type all flag

2. add notes about diffing split extracted and original data
3. add a nlu conv generator to generate conv data based on nlu utterances and entities
4. add task uid support for dumping corrections
5. abstracted generate date fn
This commit is contained in:
2020-06-24 22:50:45 +05:30
parent e76ccda5dd
commit 515e9c1037
7 changed files with 196 additions and 174 deletions

View File

@@ -181,14 +181,16 @@ def task_ui(
@app.command()
def dump_corrections(
task_uid: str,
data_name: str = typer.Option("call_alphanum", show_default=True),
dump_dir: Path = Path("./data/asr_data"),
dump_fname: Path = Path("corrections.json"),
):
dump_path = dump_dir / Path(data_name) / dump_fname
col = get_mongo_conn(col="asr_validation")
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})
task_id = [c for c in col.distinct("task_id") if c.rsplit("-", 1)[1] == task_uid][0]
corrections = list(col.find({"type": "correction"}, projection={"_id": False}))
cursor_obj = col.find({"type": "correction", "task_id": task_id}, projection={"_id": False})
corrections = [c for c in cursor_obj]
ExtendedPath(dump_path).write_json(corrections)
@@ -257,6 +259,7 @@ class ExtractionType(str, Enum):
date = "dates"
city = "cities"
name = "names"
all = "all"
@app.command()
@@ -267,50 +270,58 @@ def split_extract(
dump_dir: Path = Path("./data/asr_data"),
dump_file: Path = Path("ui_dump.json"),
manifest_file: Path = Path("manifest.json"),
corrections_file: Path = Path("corrections.json"),
corrections_file: str = typer.Option("corrections.json", show_default=True),
conv_data_path: Path = Path("./data/conv_data.json"),
extraction_type: ExtractionType = ExtractionType.date,
extraction_type: ExtractionType = ExtractionType.all,
):
import shutil
dest_data_name = data_name + "_" + extraction_type.value
data_manifest_path = dump_dir / Path(data_name) / manifest_file
conv_data = ExtendedPath(conv_data_path).read_json()
extraction_vals = conv_data[extraction_type.value]
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_path = dest_data_dir / dump_file
dest_correction_path = dest_data_dir / corrections_file
def extract_data_of_type(extraction_key):
extraction_vals = conv_data[extraction_key]
dest_data_name = data_name + "_" + extraction_key.lower()
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
yield m
manifest_gen = asr_manifest_reader(data_manifest_path)
dest_data_dir = dump_dir / Path(dest_data_name)
dest_data_dir.mkdir(exist_ok=True, parents=True)
(dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
dest_manifest_path = dest_data_dir / manifest_file
dest_ui_path = dest_data_dir / dump_file
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
def extract_manifest(mg):
for m in mg:
if m["text"] in extraction_vals:
shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
yield m
ui_data_path = dump_dir / Path(data_name) / dump_file
corrections_path = dump_dir / Path(data_name) / corrections_file
ui_data = json.load(ui_data_path.open())["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
corrections = json.load(corrections_path.open())
asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
ui_data_path = dump_dir / Path(data_name) / dump_file
ui_data = json.load(ui_data_path.open())["data"]
file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
extracted_ui_data = list(filter(lambda u: u["text"] in extraction_vals, ui_data))
ExtendedPath(dest_ui_path).write_json(extracted_ui_data)
extracted_corrections = list(
filter(
lambda c: c["code"] in file_ui_map
and file_ui_map[c["code"]]["text"] in extraction_vals,
corrections,
)
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if corrections_file:
dest_correction_path = dest_data_dir / corrections_file
corrections_path = dump_dir / Path(data_name) / corrections_file
corrections = json.load(corrections_path.open())
extracted_corrections = list(
filter(
lambda c: c["code"] in file_ui_map
and file_ui_map[c["code"]]["text"] in extraction_vals,
corrections,
)
)
ExtendedPath(dest_correction_path).write_json(extracted_corrections)
if extraction_type.value == 'all':
for ext_key in conv_data.keys():
extract_data_of_type(ext_key)
else:
extract_data_of_type(extraction_type.value)
@app.command()