mirror of
https://github.com/malarinv/jasper-asr.git
synced 2026-03-09 19:02:35 +00:00
1. implemented dataset augmentation and validation in process
2. added an option to skip 'incorrect' annotations in validation data
3. added confirmation on clearing the mongo collection
4. added an option to navigate to a given text in the validation UI
5. added dataset and remote options to the trainer to load a dataset from a directory or a remote rpyc service
This commit is contained in:
@@ -3,6 +3,7 @@ import argparse
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
@@ -36,13 +37,13 @@ def parse_args():
|
||||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
train_dataset="./train/asr_data/train_manifest.json",
|
||||
eval_datasets="./train/asr_data/test_manifest.json",
|
||||
model_config="./train/jasper-speller10x5dr.yaml",
|
||||
# train_dataset="./train/asr_data/train_manifest.json",
|
||||
# eval_datasets="./train/asr_data/test_manifest.json",
|
||||
work_dir="./train/work",
|
||||
num_epochs=50,
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
checkpoint_save_freq=1000,
|
||||
checkpoint_save_freq=200,
|
||||
eval_freq=100,
|
||||
load_dir="./train/models/jasper/",
|
||||
warmup_steps=3,
|
||||
@@ -70,7 +71,6 @@ def parse_args():
|
||||
required=False,
|
||||
help="model configuration file: model.yaml",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--remote_data",
|
||||
type=str,
|
||||
@@ -78,6 +78,13 @@ def parse_args():
|
||||
default="",
|
||||
help="remote dataloader endpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="",
|
||||
help="dataset directory containing train/test manifests",
|
||||
)
|
||||
|
||||
# Create new args
|
||||
parser.add_argument("--exp_name", default="Jasper", type=str)
|
||||
@@ -120,22 +127,26 @@ def create_all_dags(args, neural_factory):
|
||||
# Calculate num_workers for dataloader
|
||||
total_cpus = os.cpu_count()
|
||||
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
|
||||
# cpu_per_traindl = 1
|
||||
# perturb_config = jasper_params.get('perturb', None)
|
||||
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
|
||||
del train_dl_params["train"]
|
||||
del train_dl_params["eval"]
|
||||
# del train_dl_params["normalize_transcripts"]
|
||||
|
||||
if args.dataset:
|
||||
d_path = Path(args.dataset)
|
||||
if not args.train_dataset:
|
||||
args.train_dataset = str(d_path / Path("train_manifest.json"))
|
||||
if not args.eval_datasets:
|
||||
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
|
||||
|
||||
data_loader_layer = nemo_asr.AudioToTextDataLayer
|
||||
|
||||
if args.remote_data:
|
||||
train_dl_params['rpyc_host'] = args.remote_data
|
||||
train_dl_params["rpyc_host"] = args.remote_data
|
||||
data_loader_layer = RpycAudioToTextDataLayer
|
||||
# if args.remote_data:
|
||||
# # import pdb; pdb.set_trace()
|
||||
# data_loader_layer = rpyc.connect(
|
||||
# args.remote_data, 8064, config={"sync_request_timeout": 600}
|
||||
# ).root.get_data_loader()
|
||||
|
||||
data_layer = data_loader_layer(
|
||||
manifest_filepath=args.train_dataset,
|
||||
sample_rate=sample_rate,
|
||||
@@ -169,7 +180,7 @@ def create_all_dags(args, neural_factory):
|
||||
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
|
||||
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
|
||||
if args.remote_data:
|
||||
eval_dl_params['rpyc_host'] = args.remote_data
|
||||
eval_dl_params["rpyc_host"] = args.remote_data
|
||||
del eval_dl_params["train"]
|
||||
del eval_dl_params["eval"]
|
||||
data_layers_eval = []
|
||||
|
||||
Reference in New Issue
Block a user