diff --git a/jasper/data/utils.py b/jasper/data/utils.py
index 1f5f5b1..eda1a65 100644
--- a/jasper/data/utils.py
+++ b/jasper/data/utils.py
@@ -95,10 +95,12 @@ class ExtendedPath(type(Path())):
     """docstring for ExtendedPath."""
 
     def read_json(self):
+        print(f"reading json from {self}")
         with self.open("r") as jf:
             return json.load(jf)
 
     def write_json(self, data):
+        print(f"writing json to {self}")
         self.parent.mkdir(parents=True, exist_ok=True)
         with self.open("w") as jf:
             return json.dump(data, jf, indent=2)
diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py
index babd49c..0f73480 100644
--- a/jasper/data/validation/process.py
+++ b/jasper/data/validation/process.py
@@ -1,6 +1,7 @@
 import json
 import shutil
 from pathlib import Path
+from enum import Enum
 
 import typer
 from tqdm import tqdm
@@ -176,6 +177,93 @@ def fill_unannotated(
     )
 
 
+class ExtractionType(str, Enum):
+    """Kinds of utterance text that `split_extract` can pull out of a dataset."""
+
+    date = "dates"
+    city = "cities"
+    name = "names"
+
+
+@app.command()
+def split_extract(
+    data_name: str = typer.Option("call_alphanum", show_default=True),
+    dump_dir: Path = Path("./data/valiation_data"),
+    dump_file: Path = Path("ui_dump.json"),
+    manifest_dir: Path = Path("./data/asr_data"),
+    manifest_file: Path = Path("manifest.json"),
+    corrections_file: Path = Path("corrections.json"),
+    conv_data_path: Path = Path("./data/conv_data.json"),
+    extraction_type: ExtractionType = ExtractionType.date,
+):
+    """Split out the entries of a dataset whose text is a known date/city/name.
+
+    The manifest, UI dump and corrections of `data_name` are filtered down to
+    entries whose text is one of the values derived from `conv_data_path` for
+    `extraction_type`, and written to a sibling dataset named
+    `<data_name>_<extraction_type.value>`; the matching audio is copied too.
+    """
+
+    def get_conv_data(cdp):
+        from itertools import product
+
+        conv_data = ExtendedPath(cdp).read_json()
+        days = [str(i) for i in range(1, 32)]
+        months = conv_data["months"]
+        # Dates are matched as "<day> <month>" strings.
+        day_months = {d + " " + m for d, m in product(days, months)}
+        return {
+            "cities": set(conv_data["cities"]),
+            "names": set(conv_data["names"]),
+            "dates": day_months,
+        }
+
+    dest_data_name = data_name + "_" + extraction_type.value
+    data_manifest_path = manifest_dir / Path(data_name) / manifest_file
+    conv_data = get_conv_data(conv_data_path)
+    extraction_vals = conv_data[extraction_type.value]
+
+    manifest_gen = asr_manifest_reader(data_manifest_path)
+    dest_data_dir = manifest_dir / Path(dest_data_name)
+    dest_data_dir.mkdir(exist_ok=True, parents=True)
+    (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
+    dest_manifest_path = dest_data_dir / manifest_file
+    dest_ui_dir = dump_dir / Path(dest_data_name)
+    dest_ui_dir.mkdir(exist_ok=True, parents=True)
+    dest_ui_path = dest_ui_dir / dump_file
+    dest_correction_path = dest_ui_dir / corrections_file
+
+    def extract_manifest(mg):
+        for m in mg:
+            if m["text"] in extraction_vals:
+                # NOTE(review): "audio_path" is presumably added by
+                # asr_manifest_reader (resolved source path) - confirm.
+                shutil.copy(m["audio_path"], dest_data_dir / Path(m["audio_filepath"]))
+                yield m
+
+    asr_manifest_writer(dest_manifest_path, extract_manifest(manifest_gen))
+
+    ui_data_path = dump_dir / Path(data_name) / dump_file
+    corrections_path = dump_dir / Path(data_name) / corrections_file
+    ui_dump = ExtendedPath(ui_data_path).read_json()
+    ui_data = ui_dump["data"]
+    file_ui_map = {Path(u["audio_filepath"]).stem: u for u in ui_data}
+    corrections = ExtendedPath(corrections_path).read_json()
+
+    extracted_ui_data = [u for u in ui_data if u["text"] in extraction_vals]
+    # Keep the {"data": [...]} layout that is read above, so the extracted
+    # dump can itself be loaded (and split again) like the original one.
+    ExtendedPath(dest_ui_path).write_json({**ui_dump, "data": extracted_ui_data})
+
+    extracted_corrections = [
+        c
+        for c in corrections
+        if c["code"] in file_ui_map
+        and file_ui_map[c["code"]]["text"] in extraction_vals
+    ]
+    ExtendedPath(dest_correction_path).write_json(extracted_corrections)
+
+
 @app.command()
 def update_corrections(
     data_name: str = typer.Option("call_alphanum", show_default=True),
@@ -188,7 +276,7 @@ def update_corrections(
     skip_incorrect: bool = True,
 ):
     data_manifest_path = manifest_dir / Path(data_name) / manifest_file
-    corrections_path = manifest_dir / Path(data_name) / corrections_file
+    corrections_path = dump_dir / Path(data_name) / corrections_file
 
     def correct_manifest(manifest_data_gen, corrections_path):
         corrections = json.load(corrections_path.open())