diff --git a/jasper/data_utils/process.py b/jasper/data_utils/process.py
index 65e66bf..000e843 100644
--- a/jasper/data_utils/process.py
+++ b/jasper/data_utils/process.py
@@ -2,20 +2,13 @@ import json
 from pathlib import Path
 from sklearn.model_selection import train_test_split
 from .utils import asr_manifest_reader, asr_manifest_writer
+from typing import List
+from itertools import chain
 import typer
 
 app = typer.Typer()
 
 
-@app.command()
-def split_data(dataset_path: Path, test_size: float = 0.1):
-    manifest_path = dataset_path / Path("abs_manifest.json")
-    asr_data = list(asr_manifest_reader(manifest_path))
-    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
-    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
-    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
-
-
 @app.command()
 def fixate_data(dataset_path: Path):
     manifest_path = dataset_path / Path("manifest.json")
@@ -30,31 +23,42 @@ def fixate_data(dataset_path: Path):
 
 
 @app.command()
-def augment_an4():
-    an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
-    an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
-    pnr_train = Path("/dataset/asr_data/pnr_data/train_manifest.json").read_bytes()
-    pnr_test = Path("/dataset/asr_data/pnr_data/test_manifest.json").read_bytes()
-
-    with Path("/dataset/asr_data/an4_pnr/train_manifest.json").open("wb") as pf:
-        pf.write(an4_train + pnr_train)
-    with Path("/dataset/asr_data/an4_pnr/test_manifest.json").open("wb") as pf:
-        pf.write(an4_test + pnr_test)
+def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
+    reader_list = []
+    abs_manifest_path = Path("abs_manifest.json")
+    for dataset_path in src_dataset_paths:
+        manifest_path = dataset_path / abs_manifest_path
+        reader_list.append(asr_manifest_reader(manifest_path))
+    dest_dataset_path.mkdir(parents=True, exist_ok=True)
+    dest_manifest_path = dest_dataset_path / abs_manifest_path
+    asr_manifest_writer(dest_manifest_path, chain(*reader_list))
 
 
 @app.command()
-def validate_data(data_file: Path):
-    with Path(data_file).open("r") as pf:
-        pnr_jsonl = pf.readlines()
-    for (i, s) in enumerate(pnr_jsonl):
-        try:
-            d = json.loads(s)
-            audio_file = data_file.parent / Path(d["audio_filepath"])
-            if not audio_file.exists():
-                raise OSError(f"File {audio_file} not found")
-        except BaseException as e:
-            print(f'failed on {i} with "{e}"')
-    print("no errors found. seems like a valid manifest.")
+def split_data(dataset_path: Path, test_size: float = 0.1):
+    manifest_path = dataset_path / Path("abs_manifest.json")
+    asr_data = list(asr_manifest_reader(manifest_path))
+    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
+    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
+    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
+
+
+@app.command()
+def validate_data(dataset_path: Path):
+    for mf_type in ["train_manifest.json", "test_manifest.json"]:
+        data_file = dataset_path / Path(mf_type)
+        print(f"validating {data_file}.")
+        with Path(data_file).open("r") as pf:
+            pnr_jsonl = pf.readlines()
+        for (i, s) in enumerate(pnr_jsonl):
+            try:
+                d = json.loads(s)
+                audio_file = data_file.parent / Path(d["audio_filepath"])
+                if not audio_file.exists():
+                    raise OSError(f"File {audio_file} not found")
+            except BaseException as e:
+                print(f'failed on {i} with "{e}"')
+        print(f"no errors found. seems like a valid {mf_type}.")
 
 
 def main():
diff --git a/jasper/data_utils/validation/process.py b/jasper/data_utils/validation/process.py
index cdd1cbc..2dc3daa 100644
--- a/jasper/data_utils/validation/process.py
+++ b/jasper/data_utils/validation/process.py
@@ -143,8 +143,8 @@ def fill_unannotated(
 @app.command()
 def update_corrections(
     data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
     corrections_path: Path = Path("./data/valiation_data/corrections.json"),
+    skip_incorrect: bool = True,
 ):
     def correct_manifest(manifest_data_gen, corrections_path):
         corrections = json.load(corrections_path.open())
@@ -170,15 +170,18 @@ def update_corrections(
                 }
             elif d["chars"] in correction_map:
                 correct_text = correction_map[d["chars"]]
-                renamed_set.add(correct_text)
-                new_name = str(Path(correct_text).with_suffix(".wav"))
-                d["audio_path"].replace(d["audio_path"].with_name(new_name))
-                new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
-                yield {
-                    "audio_filepath": new_filepath,
-                    "duration": d["duration"],
-                    "text": alnum_to_asr_tokens(correct_text),
-                }
+                if skip_incorrect:
+                    print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
+                else:
+                    renamed_set.add(correct_text)
+                    new_name = str(Path(correct_text).with_suffix(".wav"))
+                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
+                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
+                    yield {
+                        "audio_filepath": new_filepath,
+                        "duration": d["duration"],
+                        "text": alnum_to_asr_tokens(correct_text),
+                    }
             else:
                 # don't delete if another correction points to an old file
                 if d["chars"] not in renamed_set:
@@ -202,8 +205,13 @@
 
 @app.command()
 def clear_mongo_corrections():
-    col = get_mongo_conn().test.asr_validation
-    col.delete_many({"type": "correction"})
+    delete = typer.confirm("are you sure you want to clear the mongo collection?")
+    if delete:
+        col = get_mongo_conn().test.asr_validation
+        col.delete_many({"type": "correction"})
+        typer.echo("deleted mongo collection.")
+    else:
+        typer.echo("Aborted")
 
 
 def main():
diff --git a/jasper/data_utils/validation/ui.py b/jasper/data_utils/validation/ui.py
index 7467b0e..04937f1 100644
--- a/jasper/data_utils/validation/ui.py
+++ b/jasper/data_utils/validation/ui.py
@@ -53,7 +53,7 @@ if not hasattr(st, "mongo_connected"):
 
 @st.cache()
 def load_ui_data(validation_ui_data_path: Path):
-    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
+    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
     return ExtendedPath(validation_ui_data_path).read_json()
 
 
@@ -117,6 +117,11 @@ def main(manifest: Path):
     #     pass
     # if st.button("Next Untagged"):
     #     pass
+    text_sample = st.text_input("Go to Text:", value="")
+    if text_sample != "":
+        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
+        if candidates:
+            st.update_cursor(candidates[0])
     real_idx = st.number_input(
         "Go to real-index",
         value=sample["real_idx"],
diff --git a/jasper/training_utils/train.py b/jasper/training_utils/train.py
index 4b4c97f..15f13fc 100644
--- a/jasper/training_utils/train.py
+++ b/jasper/training_utils/train.py
@@ -3,6 +3,7 @@ import argparse
 import copy
 import math
 import os
+from pathlib import Path
 from functools import partial
 
 from ruamel.yaml import YAML
@@ -36,13 +37,13 @@ def parse_args():
         lr=0.002,
         amp_opt_level="O1",
         create_tb_writer=True,
model_config="./train/jasper10x5dr.yaml", - train_dataset="./train/asr_data/train_manifest.json", - eval_datasets="./train/asr_data/test_manifest.json", + model_config="./train/jasper-speller10x5dr.yaml", + # train_dataset="./train/asr_data/train_manifest.json", + # eval_datasets="./train/asr_data/test_manifest.json", work_dir="./train/work", - num_epochs=50, + num_epochs=300, weight_decay=0.005, - checkpoint_save_freq=1000, + checkpoint_save_freq=200, eval_freq=100, load_dir="./train/models/jasper/", warmup_steps=3, @@ -70,7 +71,6 @@ def parse_args(): required=False, help="model configuration file: model.yaml", ) - parser.add_argument( "--remote_data", type=str, @@ -78,6 +78,13 @@ def parse_args(): default="", help="remote dataloader endpoint", ) + parser.add_argument( + "--dataset", + type=str, + required=False, + default="", + help="dataset directory containing train/test manifests", + ) # Create new args parser.add_argument("--exp_name", default="Jasper", type=str) @@ -120,22 +127,26 @@ def create_all_dags(args, neural_factory): # Calculate num_workers for dataloader total_cpus = os.cpu_count() cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1) - # cpu_per_traindl = 1 # perturb_config = jasper_params.get('perturb', None) train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"]) del train_dl_params["train"] del train_dl_params["eval"] # del train_dl_params["normalize_transcripts"] + + if args.dataset: + d_path = Path(args.dataset) + if not args.train_dataset: + args.train_dataset = str(d_path / Path("train_manifest.json")) + if not args.eval_datasets: + args.eval_datasets = [str(d_path / Path("test_manifest.json"))] + data_loader_layer = nemo_asr.AudioToTextDataLayer + if args.remote_data: - train_dl_params['rpyc_host'] = args.remote_data + train_dl_params["rpyc_host"] = args.remote_data data_loader_layer = RpycAudioToTextDataLayer - # if args.remote_data: - # # import pdb; pdb.set_trace() - # data_loader_layer = rpyc.connect( - # args.remote_data, 8064, config={"sync_request_timeout": 600} - # ).root.get_data_loader() + data_layer = data_loader_layer( manifest_filepath=args.train_dataset, sample_rate=sample_rate, @@ -169,7 +180,7 @@ def create_all_dags(args, neural_factory): eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"]) if args.remote_data: - eval_dl_params['rpyc_host'] = args.remote_data + eval_dl_params["rpyc_host"] = args.remote_data del eval_dl_params["train"] del eval_dl_params["eval"] data_layers_eval = []