1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-09 19:02:35 +00:00

1. implemented dataset augmentation and validation in process

2. added option to skip 'incorrect' annotations in validation data
3. added confirmation on clearing mongo collection
4. added an option to navigate to a given text entry in the validation UI
5. added `dataset` and `remote` options to the trainer, to load the dataset from a directory or from a remote rpyc service
This commit is contained in:
2020-05-20 11:16:22 +05:30
parent 83db445a6f
commit 8e79bbb571
4 changed files with 85 additions and 58 deletions

View File

@@ -3,6 +3,7 @@ import argparse
import copy
import math
import os
from pathlib import Path
from functools import partial
from ruamel.yaml import YAML
@@ -36,13 +37,13 @@ def parse_args():
lr=0.002,
amp_opt_level="O1",
create_tb_writer=True,
model_config="./train/jasper10x5dr.yaml",
train_dataset="./train/asr_data/train_manifest.json",
eval_datasets="./train/asr_data/test_manifest.json",
model_config="./train/jasper-speller10x5dr.yaml",
# train_dataset="./train/asr_data/train_manifest.json",
# eval_datasets="./train/asr_data/test_manifest.json",
work_dir="./train/work",
num_epochs=50,
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=1000,
checkpoint_save_freq=200,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
@@ -70,7 +71,6 @@ def parse_args():
required=False,
help="model configuration file: model.yaml",
)
parser.add_argument(
"--remote_data",
type=str,
@@ -78,6 +78,13 @@ def parse_args():
default="",
help="remote dataloader endpoint",
)
parser.add_argument(
"--dataset",
type=str,
required=False,
default="",
help="dataset directory containing train/test manifests",
)
# Create new args
parser.add_argument("--exp_name", default="Jasper", type=str)
@@ -120,22 +127,26 @@ def create_all_dags(args, neural_factory):
# Calculate num_workers for dataloader
total_cpus = os.cpu_count()
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
# cpu_per_traindl = 1
# perturb_config = jasper_params.get('perturb', None)
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
del train_dl_params["train"]
del train_dl_params["eval"]
# del train_dl_params["normalize_transcripts"]
if args.dataset:
d_path = Path(args.dataset)
if not args.train_dataset:
args.train_dataset = str(d_path / Path("train_manifest.json"))
if not args.eval_datasets:
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
data_loader_layer = nemo_asr.AudioToTextDataLayer
if args.remote_data:
train_dl_params['rpyc_host'] = args.remote_data
train_dl_params["rpyc_host"] = args.remote_data
data_loader_layer = RpycAudioToTextDataLayer
# if args.remote_data:
# # import pdb; pdb.set_trace()
# data_loader_layer = rpyc.connect(
# args.remote_data, 8064, config={"sync_request_timeout": 600}
# ).root.get_data_loader()
data_layer = data_loader_layer(
manifest_filepath=args.train_dataset,
sample_rate=sample_rate,
@@ -169,7 +180,7 @@ def create_all_dags(args, neural_factory):
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
if args.remote_data:
eval_dl_params['rpyc_host'] = args.remote_data
eval_dl_params["rpyc_host"] = args.remote_data
del eval_dl_params["train"]
del eval_dl_params["eval"]
data_layers_eval = []