diff --git a/jasper/data/process.py b/jasper/data/process.py
index 7a53030..0db22af 100644
--- a/jasper/data/process.py
+++ b/jasper/data/process.py
@@ -23,7 +23,7 @@ def fixate_data(dataset_path: Path):
 
 
 @app.command()
-def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
+def augment_data(src_dataset_paths: List[Path], dest_dataset_path: Path):
     reader_list = []
     abs_manifest_path = Path("abs_manifest.json")
     for dataset_path in src_dataset_paths:
diff --git a/jasper/data/utils.py b/jasper/data/utils.py
index d4d1c7f..fa3d85d 100644
--- a/jasper/data/utils.py
+++ b/jasper/data/utils.py
@@ -59,6 +59,10 @@ def alnum_to_asr_tokens(text):
     return ("".join(num_tokens)).lower()
 
 
+def tscript_uuid_fname(transcript):
+    return str(uuid4()) + "_" + slugify(transcript, max_length=8)
+
+
 def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     dataset_dir = output_dir / Path(dataset_name)
     (dataset_dir / Path("wav")).mkdir(parents=True, exist_ok=True)
@@ -67,7 +71,7 @@ def asr_data_writer(output_dir, dataset_name, asr_data_source, verbose=False):
     with asr_manifest.open("w") as mf:
         print(f"writing manifest to {asr_manifest}")
         for transcript, audio_dur, wav_data in asr_data_source:
-            fname = str(uuid4()) + "_" + slugify(transcript, max_length=8)
+            fname = tscript_uuid_fname(transcript)
             audio_file = dataset_dir / Path("wav") / Path(fname).with_suffix(".wav")
             audio_file.write_bytes(wav_data)
             rel_pnr_path = audio_file.relative_to(dataset_dir)
@@ -174,7 +178,7 @@ def asr_manifest_reader(data_manifest_path: Path):
     pnr_data = [json.loads(v) for v in pnr_jsonl]
     for p in pnr_data:
         p["audio_path"] = data_manifest_path.parent / Path(p["audio_filepath"])
-        p["chars"] = Path(p["audio_filepath"]).stem
+        p["text"] = p["text"].strip()
         yield p
 
 
diff --git a/jasper/data/validation/process.py b/jasper/data/validation/process.py
index a8b111c..750e96b 100644
--- a/jasper/data/validation/process.py
+++ b/jasper/data/validation/process.py
@@ -11,6 +11,7 @@ from ..utils import (
     ExtendedPath,
     asr_manifest_reader,
     asr_manifest_writer,
+    tscript_uuid_fname,
     get_mongo_conn,
     plot_seg,
 )
@@ -262,9 +263,9 @@ class ExtractionType(str, Enum):
 def split_extract(
     data_name: str = typer.Option("call_alphanum", show_default=True),
     # dest_data_name: str = typer.Option("call_aldata_namephanum_date", show_default=True),
-    dump_dir: Path = Path("./data/valiation_data"),
+    # dump_dir: Path = Path("./data/valiation_data"),
+    dump_dir: Path = Path("./data/asr_data"),
     dump_file: Path = Path("ui_dump.json"),
-    manifest_dir: Path = Path("./data/asr_data"),
     manifest_file: Path = Path("manifest.json"),
     corrections_file: Path = Path("corrections.json"),
     conv_data_path: Path = Path("./data/conv_data.json"),
@@ -272,33 +273,18 @@ def split_extract(
 ):
     import shutil
 
-    def get_conv_data(cdp):
-        from itertools import product
-
-        conv_data = json.load(cdp.open())
-        days = [str(i) for i in range(1, 32)]
-        months = conv_data["months"]
-        day_months = {d + " " + m for d, m in product(days, months)}
-        return {
-            "cities": set(conv_data["cities"]),
-            "names": set(conv_data["names"]),
-            "dates": day_months,
-        }
-
     dest_data_name = data_name + "_" + extraction_type.value
-    data_manifest_path = manifest_dir / Path(data_name) / manifest_file
-    conv_data = get_conv_data(conv_data_path)
+    data_manifest_path = dump_dir / Path(data_name) / manifest_file
+    conv_data = ExtendedPath(conv_data_path).read_json()
     extraction_vals = conv_data[extraction_type.value]
     manifest_gen = asr_manifest_reader(data_manifest_path)
-    dest_data_dir = manifest_dir / Path(dest_data_name)
+    dest_data_dir = dump_dir / Path(dest_data_name)
     dest_data_dir.mkdir(exist_ok=True, parents=True)
     (dest_data_dir / Path("wav")).mkdir(exist_ok=True, parents=True)
     dest_manifest_path = dest_data_dir / manifest_file
 
-    dest_ui_dir = dump_dir / Path(dest_data_name)
-    dest_ui_dir.mkdir(exist_ok=True, parents=True)
-    dest_ui_path = dest_ui_dir / dump_file
-    dest_correction_path = dest_ui_dir / corrections_file
+    dest_ui_path = dest_data_dir / dump_file
+    dest_correction_path = dest_data_dir / corrections_file
 
     def extract_manifest(mg):
         for m in mg:
@@ -330,19 +316,19 @@ def split_extract(
 @app.command()
 def update_corrections(
     data_name: str = typer.Option("call_alphanum", show_default=True),
-    dump_dir: Path = Path("./data/valiation_data"),
-    manifest_dir: Path = Path("./data/asr_data"),
+    dump_dir: Path = Path("./data/asr_data"),
     manifest_file: Path = Path("manifest.json"),
     corrections_file: Path = Path("corrections.json"),
-    # data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    # corrections_path: Path = Path("./data/valiation_data/corrections.json"),
+    ui_dump_file: Path = Path("ui_dump.json"),
     skip_incorrect: bool = True,
 ):
-    data_manifest_path = manifest_dir / Path(data_name) / manifest_file
+    data_manifest_path = dump_dir / Path(data_name) / manifest_file
     corrections_path = dump_dir / Path(data_name) / corrections_file
+    ui_dump_path = dump_dir / Path(data_name) / ui_dump_file
 
-    def correct_manifest(manifest_data_gen, corrections_path):
-        corrections = json.load(corrections_path.open())
+    def correct_manifest(ui_dump_path, corrections_path):
+        corrections = ExtendedPath(corrections_path).read_json()
+        ui_data = ExtendedPath(ui_dump_path).read_json()['data']
         correct_set = {
             c["code"] for c in corrections if c["value"]["status"] == "Correct"
         }
@@ -355,36 +341,38 @@ def update_corrections(
         # for d in manifest_data_gen:
         #     if d["chars"] in incorrect_set:
        #         d["audio_path"].unlink()
-        renamed_set = set()
-        for d in manifest_data_gen:
-            if d["chars"] in correct_set:
+        # renamed_set = set()
+        for d in ui_data:
+            if d["utterance_id"] in correct_set:
                 yield {
                     "audio_filepath": d["audio_filepath"],
                     "duration": d["duration"],
                     "text": d["text"],
                 }
-            elif d["chars"] in correction_map:
-                correct_text = correction_map[d["chars"]]
+            elif d["utterance_id"] in correction_map:
+                correct_text = correction_map[d["utterance_id"]]
                 if skip_incorrect:
                     print(
                         f'skipping incorrect {d["audio_path"]} corrected to {correct_text}'
                     )
                 else:
-                    renamed_set.add(correct_text)
-                    new_name = str(Path(correct_text).with_suffix(".wav"))
-                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
+                    orig_audio_path = Path(d["audio_path"])
+                    new_name = str(Path(tscript_uuid_fname(correct_text)).with_suffix(".wav"))
+                    new_audio_path = orig_audio_path.with_name(new_name)
+                    orig_audio_path.replace(new_audio_path)
                     new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
                     yield {
                         "audio_filepath": new_filepath,
                         "duration": d["duration"],
-                        "text": alnum_to_asr_tokens(correct_text),
+                        "text": correct_text,
                     }
             else:
+                orig_audio_path = Path(d["audio_path"])
                 # don't delete if another correction points to an old file
-                if d["chars"] not in renamed_set:
-                    d["audio_path"].unlink()
-                else:
-                    print(f'skipping deletion of correction:{d["chars"]}')
+                # if d["text"] not in renamed_set:
+                orig_audio_path.unlink()
+                # else:
+                #     print(f'skipping deletion of correction:{d["text"]}')
 
     typer.echo(f"Using data manifest:{data_manifest_path}")
     dataset_dir = data_manifest_path.parent
@@ -393,8 +381,8 @@ def update_corrections(
     if not backup_dir.exists():
         typer.echo(f"backing up to :{backup_dir}")
         shutil.copytree(str(dataset_dir), str(backup_dir))
-    manifest_gen = asr_manifest_reader(data_manifest_path)
-    corrected_manifest = correct_manifest(manifest_gen, corrections_path)
+    # manifest_gen = asr_manifest_reader(data_manifest_path)
+    corrected_manifest = correct_manifest(ui_dump_path, corrections_path)
     new_data_manifest_path = data_manifest_path.with_name("manifest.new")
     asr_manifest_writer(new_data_manifest_path, corrected_manifest)
     new_data_manifest_path.replace(data_manifest_path)
diff --git a/jasper/training/cli.py b/jasper/training/cli.py
index 1f628bd..7ef9beb 100644
--- a/jasper/training/cli.py
+++ b/jasper/training/cli.py
@@ -41,7 +41,7 @@ def parse_args():
         work_dir="./train/work",
         num_epochs=300,
         weight_decay=0.005,
-        checkpoint_save_freq=200,
+        checkpoint_save_freq=100,
         eval_freq=100,
         load_dir="./train/models/jasper/",
         warmup_steps=3,
@@ -266,6 +266,7 @@ def create_all_dags(args, neural_factory):
         folder=neural_factory.checkpoint_dir,
         load_from_folder=args.load_dir,
         step_freq=args.checkpoint_save_freq,
+        checkpoints_to_keep=30,
     )
 
     callbacks = [train_callback, chpt_callback]
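For reference, the behaviour of the tscript_uuid_fname helper now shared by asr_data_writer and the corrections rename path can be sketched in isolation. This is a minimal standalone sketch under the assumption that, as the calls in jasper/data/utils.py suggest, uuid4 comes from the standard library and slugify from the python-slugify package:

from uuid import uuid4       # stdlib UUID generator (assumed to be the existing import in utils.py)
from slugify import slugify  # python-slugify (assumed to be the existing import in utils.py)


def tscript_uuid_fname(transcript):
    # Unique, filesystem-safe stem: a random uuid4 plus at most 8 slugified
    # characters of the transcript, e.g. "<uuid4>_hello-wo" for "hello world".
    return str(uuid4()) + "_" + slugify(transcript, max_length=8)


if __name__ == "__main__":
    # Both the dataset writer and the corrections flow derive their .wav stems this way,
    # so corrected clips are renamed with the same scheme as freshly written ones.
    print(tscript_uuid_fname("flight to new york on 3 june"))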