From 8e79bbb57139a7f937d5175a38a35432abed0cf1 Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Wed, 20 May 2020 11:16:22 +0530
Subject: [PATCH] dataset augmentation, validation and trainer dataset options

1. implement dataset augmentation and validation commands in process
2. add an option to skip 'incorrect' annotations in the validation data
3. ask for confirmation before clearing the mongo corrections collection
4. add an option to navigate to a given text in the validation ui
5. add dataset and remote options to the trainer, so the dataset can be
   loaded from a directory or from a remote rpyc service
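
Example invocations of the new typer commands (a sketch; the `python -m`
entrypoint and the dataset paths are assumptions, not part of this patch --
typer maps underscores in function names to dashes):

    # merge the abs_manifest.json of several source datasets into one
    # (sources first, destination last)
    python -m jasper.data_utils.process augment-datasets \
        /dataset/asr_data/an4 /dataset/asr_data/pnr_data \
        /dataset/asr_data/an4_pnr

    # re-split the merged dataset, then check both manifests for
    # unparsable lines and missing audio files
    python -m jasper.data_utils.process split-data /dataset/asr_data/an4_pnr
    python -m jasper.data_utils.process validate-data /dataset/asr_data/an4_pnr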
---
 jasper/data_utils/process.py            | 66 ++++++++++++------------
 jasper/data_utils/validation/process.py | 32 ++++++++------
 jasper/data_utils/validation/ui.py      |  7 ++++-
 jasper/training_utils/train.py          | 39 ++++++++-------
 4 files changed, 86 insertions(+), 58 deletions(-)

diff --git a/jasper/data_utils/process.py b/jasper/data_utils/process.py
index 65e66bf..000e843 100644
--- a/jasper/data_utils/process.py
+++ b/jasper/data_utils/process.py
@@ -2,20 +2,13 @@ import json
 from pathlib import Path
 from sklearn.model_selection import train_test_split
 from .utils import asr_manifest_reader, asr_manifest_writer
+from typing import List
+from itertools import chain
 import typer
 
 app = typer.Typer()
 
 
-@app.command()
-def split_data(dataset_path: Path, test_size: float = 0.1):
-    manifest_path = dataset_path / Path("abs_manifest.json")
-    asr_data = list(asr_manifest_reader(manifest_path))
-    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
-    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
-    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
-
-
 @app.command()
 def fixate_data(dataset_path: Path):
     manifest_path = dataset_path / Path("manifest.json")
@@ -30,31 +23,42 @@ def fixate_data(dataset_path: Path):
 
 
 @app.command()
-def augment_an4():
-    an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
-    an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
-    pnr_train = Path("/dataset/asr_data/pnr_data/train_manifest.json").read_bytes()
-    pnr_test = Path("/dataset/asr_data/pnr_data/test_manifest.json").read_bytes()
-
-    with Path("/dataset/asr_data/an4_pnr/train_manifest.json").open("wb") as pf:
-        pf.write(an4_train + pnr_train)
-    with Path("/dataset/asr_data/an4_pnr/test_manifest.json").open("wb") as pf:
-        pf.write(an4_test + pnr_test)
+def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
+    reader_list = []
+    abs_manifest_path = Path("abs_manifest.json")
+    for dataset_path in src_dataset_paths:
+        manifest_path = dataset_path / abs_manifest_path
+        reader_list.append(asr_manifest_reader(manifest_path))
+    dest_dataset_path.mkdir(parents=True, exist_ok=True)
+    dest_manifest_path = dest_dataset_path / abs_manifest_path
+    asr_manifest_writer(dest_manifest_path, chain(*reader_list))
 
 
 @app.command()
-def validate_data(data_file: Path):
-    with Path(data_file).open("r") as pf:
-        pnr_jsonl = pf.readlines()
-    for (i, s) in enumerate(pnr_jsonl):
-        try:
-            d = json.loads(s)
-            audio_file = data_file.parent / Path(d["audio_filepath"])
-            if not audio_file.exists():
-                raise OSError(f"File {audio_file} not found")
-        except BaseException as e:
-            print(f'failed on {i} with "{e}"')
-    print("no errors found. seems like a valid manifest.")
+def split_data(dataset_path: Path, test_size: float = 0.1):
+    manifest_path = dataset_path / Path("abs_manifest.json")
+    asr_data = list(asr_manifest_reader(manifest_path))
+    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
+    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
+    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
+
+
+@app.command()
+def validate_data(dataset_path: Path):
+    for mf_type in ["train_manifest.json", "test_manifest.json"]:
+        data_file = dataset_path / Path(mf_type)
+        print(f"validating {data_file}.")
+        with Path(data_file).open("r") as pf:
+            pnr_jsonl = pf.readlines()
+        for (i, s) in enumerate(pnr_jsonl):
+            try:
+                d = json.loads(s)
+                audio_file = data_file.parent / Path(d["audio_filepath"])
+                if not audio_file.exists():
+                    raise OSError(f"File {audio_file} not found")
            except BaseException as e:
+                print(f'failed on {i} with "{e}"')
+        print(f"no errors found. seems like a valid {mf_type}.")
 
 
 def main():
diff --git a/jasper/data_utils/validation/process.py b/jasper/data_utils/validation/process.py
index cdd1cbc..2dc3daa 100644
--- a/jasper/data_utils/validation/process.py
+++ b/jasper/data_utils/validation/process.py
@@ -143,8 +143,8 @@ def fill_unannotated(
 @app.command()
 def update_corrections(
     data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
     corrections_path: Path = Path("./data/valiation_data/corrections.json"),
+    skip_incorrect: bool = True,
 ):
     def correct_manifest(manifest_data_gen, corrections_path):
         corrections = json.load(corrections_path.open())
@@ -170,15 +170,18 @@ def update_corrections(
                 }
             elif d["chars"] in correction_map:
                 correct_text = correction_map[d["chars"]]
-                renamed_set.add(correct_text)
-                new_name = str(Path(correct_text).with_suffix(".wav"))
-                d["audio_path"].replace(d["audio_path"].with_name(new_name))
-                new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
-                yield {
-                    "audio_filepath": new_filepath,
-                    "duration": d["duration"],
-                    "text": alnum_to_asr_tokens(correct_text),
-                }
+                if skip_incorrect:
+                    print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
+                else:
+                    renamed_set.add(correct_text)
+                    new_name = str(Path(correct_text).with_suffix(".wav"))
+                    d["audio_path"].replace(d["audio_path"].with_name(new_name))
+                    new_filepath = str(Path(d["audio_filepath"]).with_name(new_name))
+                    yield {
+                        "audio_filepath": new_filepath,
+                        "duration": d["duration"],
+                        "text": alnum_to_asr_tokens(correct_text),
+                    }
             else:
                 # don't delete if another correction points to an old file
                 if d["chars"] not in renamed_set:
@@ -202,8 +205,13 @@ def update_corrections(
 
 @app.command()
 def clear_mongo_corrections():
-    col = get_mongo_conn().test.asr_validation
-    col.delete_many({"type": "correction"})
+    delete = typer.confirm("are you sure you want to clear the mongo corrections collection?")
+    if delete:
+        col = get_mongo_conn().test.asr_validation
+        col.delete_many({"type": "correction"})
+        typer.echo("deleted mongo collection.")
+    else:
+        typer.echo("aborted.")
 
 
 def main():
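
Review note (not part of the patch): typer ships a shorthand for this
confirm-or-abort pattern; a minimal alternative sketch, assuming the same
get_mongo_conn helper and app object:

    @app.command()
    def clear_mongo_corrections():
        # abort=True raises typer.Abort (exit code 1) on a "no" answer,
        # so no explicit else branch is needed
        typer.confirm("clear the mongo corrections collection?", abort=True)
        col = get_mongo_conn().test.asr_validation
        col.delete_many({"type": "correction"})
        typer.echo("deleted mongo collection.")
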
diff --git a/jasper/data_utils/validation/ui.py b/jasper/data_utils/validation/ui.py
index 7467b0e..04937f1 100644
--- a/jasper/data_utils/validation/ui.py
+++ b/jasper/data_utils/validation/ui.py
@@ -53,7 +53,7 @@ if not hasattr(st, "mongo_connected"):
 
 @st.cache()
 def load_ui_data(validation_ui_data_path: Path):
-    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
+    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
     return ExtendedPath(validation_ui_data_path).read_json()
 
 
@@ -117,6 +117,11 @@ def main(manifest: Path):
     #     pass
     # if st.button("Next Untagged"):
     #     pass
+    text_sample = st.text_input("Go to Text:", value="")
+    if text_sample != "":
+        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
+        if len(candidates) > 0:
+            st.update_cursor(candidates[0])
     real_idx = st.number_input(
         "Go to real-index",
         value=sample["real_idx"],
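
Review note (not part of the patch): the new "Go to Text" box matches only
on exact equality against "text"/"spoken"; a substring match may be easier
to use. A minimal sketch, assuming the same asr_data records and the
update_cursor helper this UI attaches to st:

    text_sample = st.text_input("Go to Text:", value="")
    if text_sample:
        # substring match, so a partial transcript is enough to navigate
        candidates = [
            i
            for (i, p) in enumerate(asr_data)
            if text_sample in p["text"] or text_sample in p["spoken"]
        ]
        if candidates:
            st.update_cursor(candidates[0])
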
diff --git a/jasper/training_utils/train.py b/jasper/training_utils/train.py
index 4b4c97f..15f13fc 100644
--- a/jasper/training_utils/train.py
+++ b/jasper/training_utils/train.py
@@ -3,6 +3,7 @@ import argparse
 import copy
 import math
 import os
+from pathlib import Path
 from functools import partial
 
 from ruamel.yaml import YAML
@@ -36,13 +37,13 @@ def parse_args():
         lr=0.002,
         amp_opt_level="O1",
         create_tb_writer=True,
-        model_config="./train/jasper10x5dr.yaml",
-        train_dataset="./train/asr_data/train_manifest.json",
-        eval_datasets="./train/asr_data/test_manifest.json",
+        model_config="./train/jasper-speller10x5dr.yaml",
+        # train_dataset="./train/asr_data/train_manifest.json",
+        # eval_datasets="./train/asr_data/test_manifest.json",
         work_dir="./train/work",
-        num_epochs=50,
+        num_epochs=300,
         weight_decay=0.005,
-        checkpoint_save_freq=1000,
+        checkpoint_save_freq=200,
         eval_freq=100,
         load_dir="./train/models/jasper/",
         warmup_steps=3,
@@ -70,7 +71,6 @@ def parse_args():
         required=False,
         help="model configuration file: model.yaml",
     )
-
     parser.add_argument(
         "--remote_data",
         type=str,
@@ -78,6 +78,13 @@ def parse_args():
         default="",
         help="remote dataloader endpoint",
     )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=False,
+        default="",
+        help="dataset directory containing train/test manifests",
+    )
 
     # Create new args
     parser.add_argument("--exp_name", default="Jasper", type=str)
@@ -120,22 +127,26 @@ def create_all_dags(args, neural_factory):
     # Calculate num_workers for dataloader
     total_cpus = os.cpu_count()
     cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
-    # cpu_per_traindl = 1
 
     # perturb_config = jasper_params.get('perturb', None)
     train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
     train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
     del train_dl_params["train"]
     del train_dl_params["eval"]
     # del train_dl_params["normalize_transcripts"]
+
+    if args.dataset:
+        d_path = Path(args.dataset)
+        if not args.train_dataset:
+            args.train_dataset = str(d_path / Path("train_manifest.json"))
+        if not args.eval_datasets:
+            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
+
+    data_loader_layer = nemo_asr.AudioToTextDataLayer
     if args.remote_data:
-        train_dl_params['rpyc_host'] = args.remote_data
+        train_dl_params["rpyc_host"] = args.remote_data
         data_loader_layer = RpycAudioToTextDataLayer
-    # if args.remote_data:
-    #     # import pdb; pdb.set_trace()
-    #     data_loader_layer = rpyc.connect(
-    #         args.remote_data, 8064, config={"sync_request_timeout": 600}
-    #     ).root.get_data_loader()
+
     data_layer = data_loader_layer(
         manifest_filepath=args.train_dataset,
         sample_rate=sample_rate,
@@ -169,7 +180,7 @@ def create_all_dags(args, neural_factory):
     eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
     eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
     if args.remote_data:
-        eval_dl_params['rpyc_host'] = args.remote_data
+        eval_dl_params["rpyc_host"] = args.remote_data
     del eval_dl_params["train"]
     del eval_dl_params["eval"]
     data_layers_eval = []
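
Review note (not part of the patch): with the new flags the trainer can
resolve manifests from a dataset directory, or stream batches through the
rpyc data layer; sketch invocations (the paths, the host, and relying on
the built-in defaults for all other flags are assumptions):

    # load train_manifest.json / test_manifest.json from a directory
    python jasper/training_utils/train.py --dataset /dataset/asr_data/an4_pnr

    # additionally read audio batches via the rpyc dataloader on that host
    python jasper/training_utils/train.py --dataset /dataset/asr_data/an4_pnr \
        --remote_data 10.0.0.5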