1. implemented dataset augmentation and validation in process

2. added an option to skip 'incorrect' annotations in validation data
3. added a confirmation prompt before clearing the mongo collection
4. added an option to navigate to a given text in the validation ui
5. added dataset and remote options to the trainer to load the dataset from a directory or a remote rpyc service
Malar Kannan 2020-05-20 11:16:22 +05:30
parent 83db445a6f
commit 8e79bbb571
4 changed files with 85 additions and 58 deletions

View File

@@ -2,20 +2,13 @@ import json
 from pathlib import Path
 from sklearn.model_selection import train_test_split
 from .utils import asr_manifest_reader, asr_manifest_writer
+from typing import List
+from itertools import chain
 import typer

 app = typer.Typer()

-@app.command()
-def split_data(dataset_path: Path, test_size: float = 0.1):
-    manifest_path = dataset_path / Path("abs_manifest.json")
-    asr_data = list(asr_manifest_reader(manifest_path))
-    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
-    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
-    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
 @app.command()
 def fixate_data(dataset_path: Path):
     manifest_path = dataset_path / Path("manifest.json")
@@ -30,20 +23,31 @@ def fixate_data(dataset_path: Path):
 @app.command()
-def augment_an4():
-    an4_train = Path("/dataset/asr_data/an4/train_manifest.json").read_bytes()
-    an4_test = Path("/dataset/asr_data/an4/test_manifest.json").read_bytes()
-    pnr_train = Path("/dataset/asr_data/pnr_data/train_manifest.json").read_bytes()
-    pnr_test = Path("/dataset/asr_data/pnr_data/test_manifest.json").read_bytes()
-    with Path("/dataset/asr_data/an4_pnr/train_manifest.json").open("wb") as pf:
-        pf.write(an4_train + pnr_train)
-    with Path("/dataset/asr_data/an4_pnr/test_manifest.json").open("wb") as pf:
-        pf.write(an4_test + pnr_test)
+def augment_datasets(src_dataset_paths: List[Path], dest_dataset_path: Path):
+    reader_list = []
+    abs_manifest_path = Path("abs_manifest.json")
+    for dataset_path in src_dataset_paths:
+        manifest_path = dataset_path / abs_manifest_path
+        reader_list.append(asr_manifest_reader(manifest_path))
+    dest_dataset_path.mkdir(parents=True, exist_ok=True)
+    dest_manifest_path = dest_dataset_path / abs_manifest_path
+    asr_manifest_writer(dest_manifest_path, chain(*reader_list))
 @app.command()
-def validate_data(data_file: Path):
+def split_data(dataset_path: Path, test_size: float = 0.1):
+    manifest_path = dataset_path / Path("abs_manifest.json")
+    asr_data = list(asr_manifest_reader(manifest_path))
+    train_pnr, test_pnr = train_test_split(asr_data, test_size=test_size)
+    asr_manifest_writer(manifest_path.with_name("train_manifest.json"), train_pnr)
+    asr_manifest_writer(manifest_path.with_name("test_manifest.json"), test_pnr)
+@app.command()
+def validate_data(dataset_path: Path):
+    for mf_type in ["train_manifest.json", "test_manifest.json"]:
+        data_file = dataset_path / Path(mf_type)
+        print(f"validating {data_file}.")
         with Path(data_file).open("r") as pf:
             pnr_jsonl = pf.readlines()
         for (i, s) in enumerate(pnr_jsonl):
@@ -54,7 +58,7 @@ def validate_data(data_file: Path):
                 raise OSError(f"File {audio_file} not found")
             except BaseException as e:
                 print(f'failed on {i} with "{e}"')
-    print("no errors found. seems like a valid manifest.")
+        print(f"no errors found. seems like a valid {mf_type}.")

 def main():
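
Note: the commands above are plain typer subcommands, so the reworked pipeline can also be driven directly from Python. A minimal sketch, assuming this file is importable as `process` and using illustrative dataset paths (neither the module name nor the paths are part of this commit):

# minimal usage sketch; the `process` module name and the paths are assumptions
from pathlib import Path

from process import augment_datasets, split_data, validate_data

srcs = [Path("/dataset/asr_data/an4"), Path("/dataset/asr_data/pnr_data")]
dest = Path("/dataset/asr_data/an4_pnr")

augment_datasets(srcs, dest)     # merge each source's abs_manifest.json into one
split_data(dest, test_size=0.1)  # write train_manifest.json / test_manifest.json
validate_data(dest)              # check every entry and audio file in both manifests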

View File

@@ -143,8 +143,8 @@ def fill_unannotated(
 @app.command()
 def update_corrections(
     data_manifest_path: Path = Path("./data/asr_data/call_alphanum/manifest.json"),
-    processed_data_path: Path = Path("./data/valiation_data/ui_dump.json"),
     corrections_path: Path = Path("./data/valiation_data/corrections.json"),
+    skip_incorrect: bool = True,
 ):
     def correct_manifest(manifest_data_gen, corrections_path):
         corrections = json.load(corrections_path.open())
@@ -170,6 +170,9 @@ def update_corrections(
             }
         elif d["chars"] in correction_map:
             correct_text = correction_map[d["chars"]]
+            if skip_incorrect:
+                print(f'skipping incorrect {d["audio_path"]} corrected to {correct_text}')
+            else:
                 renamed_set.add(correct_text)
                 new_name = str(Path(correct_text).with_suffix(".wav"))
                 d["audio_path"].replace(d["audio_path"].with_name(new_name))
@@ -202,8 +205,13 @@ def update_corrections(
 @app.command()
 def clear_mongo_corrections():
+    delete = typer.confirm("are you sure you want to clear the mongo collection?")
+    if delete:
         col = get_mongo_conn().test.asr_validation
         col.delete_many({"type": "correction"})
+        typer.echo("deleted mongo collection.")
+    else:
+        typer.echo("Aborted")

 def main():
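
Because `skip_incorrect` is a bool option defaulting to True, typer exposes it on the CLI as a `--skip-incorrect/--no-skip-incorrect` flag pair. A minimal sketch of exercising the new behaviour with typer's test runner, assuming this file is importable as `validation` (the module name is an assumption):

# sketch using typer.testing.CliRunner; the `validation` module name is assumed
from typer.testing import CliRunner

from validation import app  # the typer.Typer() app defined in this file

runner = CliRunner()
# default: entries marked incorrect are only logged, not renamed
runner.invoke(app, ["update-corrections"])
# opt back in to renaming the audio files for corrected entries
runner.invoke(app, ["update-corrections", "--no-skip-incorrect"])
# clearing now prompts first; feeding "n" answers the typer.confirm() with no
result = runner.invoke(app, ["clear-mongo-corrections"], input="n\n")
print(result.output)  # should end with "Aborted"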

View File

@@ -53,7 +53,7 @@ if not hasattr(st, "mongo_connected"):
 @st.cache()
 def load_ui_data(validation_ui_data_path: Path):
-    typer.echo(f"Using validation ui data from :{validation_ui_data_path}")
+    typer.echo(f"Using validation ui data from {validation_ui_data_path}")
     return ExtendedPath(validation_ui_data_path).read_json()
@@ -117,6 +117,11 @@ def main(manifest: Path):
     #     pass
     # if st.button("Next Untagged"):
     #     pass
+    text_sample = st.text_input("Go to Text:", value="")
+    if text_sample != "":
+        candidates = [i for (i, p) in enumerate(asr_data) if p["text"] == text_sample or p["spoken"] == text_sample]
+        if len(candidates) > 0:
+            st.update_cursor(candidates[0])
     real_idx = st.number_input(
         "Go to real-index",
         value=sample["real_idx"],
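
Note that `st.update_cursor` is not a stock Streamlit API; given the `hasattr(st, "mongo_connected")` guard earlier in this file, it is presumably a helper attached to the `st` module elsewhere in the app. The lookup itself is an exact match on either field, as this standalone sketch shows (the sample data is made up):

# standalone sketch of the "Go to Text" lookup; the data below is made up
asr_data = [
    {"text": "alpha one", "spoken": "alfa won"},
    {"text": "bravo two", "spoken": "bravo too"},
]
text_sample = "bravo too"
candidates = [i for (i, p) in enumerate(asr_data)
              if p["text"] == text_sample or p["spoken"] == text_sample]
if len(candidates) > 0:
    print(f"jump to sample index {candidates[0]}")  # prints: jump to sample index 1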

View File

@@ -3,6 +3,7 @@ import argparse
 import copy
 import math
 import os
+from pathlib import Path
 from functools import partial
 from ruamel.yaml import YAML
@@ -36,13 +37,13 @@ def parse_args():
         lr=0.002,
         amp_opt_level="O1",
         create_tb_writer=True,
-        model_config="./train/jasper10x5dr.yaml",
-        train_dataset="./train/asr_data/train_manifest.json",
-        eval_datasets="./train/asr_data/test_manifest.json",
+        model_config="./train/jasper-speller10x5dr.yaml",
+        # train_dataset="./train/asr_data/train_manifest.json",
+        # eval_datasets="./train/asr_data/test_manifest.json",
         work_dir="./train/work",
-        num_epochs=50,
+        num_epochs=300,
         weight_decay=0.005,
-        checkpoint_save_freq=1000,
+        checkpoint_save_freq=200,
         eval_freq=100,
         load_dir="./train/models/jasper/",
         warmup_steps=3,
@@ -70,7 +71,6 @@ def parse_args():
         required=False,
         help="model configuration file: model.yaml",
     )
     parser.add_argument(
         "--remote_data",
         type=str,
@@ -78,6 +78,13 @@ def parse_args():
         default="",
         help="remote dataloader endpoint",
     )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        required=False,
+        default="",
+        help="dataset directory containing train/test manifests",
+    )
     # Create new args
     parser.add_argument("--exp_name", default="Jasper", type=str)
@@ -120,22 +127,26 @@ def create_all_dags(args, neural_factory):
     # Calculate num_workers for dataloader
     total_cpus = os.cpu_count()
     cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
-    # cpu_per_traindl = 1
     # perturb_config = jasper_params.get('perturb', None)
     train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
     train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
     del train_dl_params["train"]
     del train_dl_params["eval"]
     # del train_dl_params["normalize_transcripts"]
+    if args.dataset:
+        d_path = Path(args.dataset)
+        if not args.train_dataset:
+            args.train_dataset = str(d_path / Path("train_manifest.json"))
+        if not args.eval_datasets:
+            args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
     data_loader_layer = nemo_asr.AudioToTextDataLayer
     if args.remote_data:
-        train_dl_params['rpyc_host'] = args.remote_data
+        train_dl_params["rpyc_host"] = args.remote_data
         data_loader_layer = RpycAudioToTextDataLayer
-    # if args.remote_data:
-    #     # import pdb; pdb.set_trace()
-    #     data_loader_layer = rpyc.connect(
-    #         args.remote_data, 8064, config={"sync_request_timeout": 600}
-    #     ).root.get_data_loader()
     data_layer = data_loader_layer(
         manifest_filepath=args.train_dataset,
         sample_rate=sample_rate,
@@ -169,7 +180,7 @@ def create_all_dags(args, neural_factory):
     eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
     eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
     if args.remote_data:
-        eval_dl_params['rpyc_host'] = args.remote_data
+        eval_dl_params["rpyc_host"] = args.remote_data
     del eval_dl_params["train"]
     del eval_dl_params["eval"]
     data_layers_eval = []
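
The net effect of the new --dataset option: when explicit manifest paths are not supplied, they are derived from the dataset directory. A standalone sketch of that fallback logic (the directory path is illustrative):

from argparse import Namespace
from pathlib import Path

# stand-in for parse_args() output when only --dataset is supplied
args = Namespace(dataset="/dataset/asr_data/an4_pnr", train_dataset="", eval_datasets="")

if args.dataset:
    d_path = Path(args.dataset)
    if not args.train_dataset:
        args.train_dataset = str(d_path / Path("train_manifest.json"))
    if not args.eval_datasets:
        args.eval_datasets = [str(d_path / Path("test_manifest.json"))]

print(args.train_dataset)  # /dataset/asr_data/an4_pnr/train_manifest.json
print(args.eval_datasets)  # ['/dataset/asr_data/an4_pnr/test_manifest.json']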