# Copyright (c) 2019 NVIDIA Corporation
"""Evaluate a pre-trained Jasper ASR model on one or more manifests.

The script restores a Jasper encoder/decoder pair from checkpoint files,
wires an evaluation graph (preprocessor -> encoder -> decoder -> CTC loss and
greedy decoding) for every evaluation manifest, and runs the NeMo evaluation
action on the resulting callbacks. No training graph is built.
"""
import argparse
import copy
import os
from functools import partial
from pathlib import Path

from ruamel.yaml import YAML

import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
    process_evaluation_batch,
    process_evaluation_epoch,
)
from training.data_loaders import RpycAudioToTextDataLayer

logging = nemo.logging


def parse_args():
    """Parse the command line on top of the stock NeMo argument parser.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: if both ``max_steps`` and ``num_epochs`` end up ``None``.
    """
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description="Jasper",
        conflict_handler="resolve",
    )
    # NOTE(review): --exp_name, --warmup_steps and --load_dir are re-declared
    # below with explicit defaults. argparse applies an argument's own default
    # before parser-level defaults, so the values set here for those three are
    # presumably never used -- confirm which value is intended.
    parser.set_defaults(
        checkpoint_dir=None,
        optimizer="novograd",
        batch_size=64,
        eval_batch_size=64,
        lr=0.002,
        amp_opt_level="O1",
        create_tb_writer=True,
        model_config="./train/jasper10x5dr.yaml",
        work_dir="./train/work",
        num_epochs=300,
        weight_decay=0.005,
        checkpoint_save_freq=100,
        eval_freq=100,
        load_dir="./train/models/jasper/",
        warmup_steps=3,
        exp_name="jasper",
    )

    # Overwrite default args
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        required=False,
        help="max number of steps to train",
    )
    parser.add_argument(
        "--num_epochs", type=int, required=False, help="number of epochs to train"
    )
    parser.add_argument(
        "--model_config",
        type=str,
        required=False,
        help="model configuration file: model.yaml",
    )
    parser.add_argument(
        "--encoder_checkpoint",
        type=str,
        required=True,
        help="encoder checkpoint file: JasperEncoder.pt",
    )
    parser.add_argument(
        "--decoder_checkpoint",
        type=str,
        required=True,
        help="decoder checkpoint file: JasperDecoderForCTC.pt",
    )
    parser.add_argument(
        "--remote_data",
        type=str,
        required=False,
        default="",
        help="remote dataloader endpoint",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=False,
        default="",
        help="dataset directory containing train/test manifests",
    )

    # Create new args
    parser.add_argument("--exp_name", default="Jasper", type=str)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--warmup_steps", default=0, type=int)
    parser.add_argument(
        "--load_dir",
        default=None,
        type=str,
        help="directory with pre-trained checkpoint",
    )

    args = parser.parse_args()
    if args.max_steps is None and args.num_epochs is None:
        raise ValueError("Either max_steps or num_epochs should be provided.")
    return args


def construct_name(
    name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
    """Build an experiment name that encodes the main hyper-parameters.

    The run length is encoded as ``s_<max_steps>`` when ``max_steps`` is not
    ``None`` and as ``e_<num_epochs>`` otherwise.

    Returns:
        A string of the form
        ``<name>-lr_<lr>-bs_<batch_size>-<s|e>_<n>-wd_<wd>-opt_<optimizer>-ips_<iter_per_step>``.
    """
    if max_steps is not None:
        duration = "s_{0}".format(max_steps)
    else:
        duration = "e_{0}".format(num_epochs)
    return "{0}-lr_{1}-bs_{2}-{3}-wd_{4}-opt_{5}-ips_{6}".format(
        name, lr, batch_size, duration, wd, optimizer, iter_per_step
    )


def create_all_dags(args, neural_factory):
    """Build the evaluation graphs and return their callbacks.

    Args:
        args: parsed command-line namespace. Reads ``model_config``,
            ``dataset``, ``train_dataset``, ``eval_datasets``, ``remote_data``,
            ``eval_batch_size``, ``encoder_checkpoint``, ``decoder_checkpoint``
            and ``eval_freq``. When ``dataset`` is set, missing
            ``train_dataset`` / ``eval_datasets`` values are filled in on
            ``args`` from that directory.
        neural_factory: the NeMo factory; ``world_size`` and ``tb_writer`` are
            read from it.

    Returns:
        A list with one ``nemo.core.EvaluatorCallback`` per evaluation
        manifest (empty when no evaluation manifest is available).
    """
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    vocab = jasper_params["labels"]
    sample_rate = jasper_params["sample_rate"]

    # Split the available CPUs between the processes of this job.
    # os.cpu_count() may return None when the count cannot be determined.
    total_cpus = os.cpu_count() or 1
    workers_per_loader = max(total_cpus // neural_factory.world_size, 1)

    # A dataset directory supplies default manifest locations.
    if args.dataset:
        d_path = Path(args.dataset)
        if not args.train_dataset:
            args.train_dataset = str(d_path / "train_manifest.json")
        if not args.eval_datasets:
            args.eval_datasets = [str(d_path / "test_manifest.json")]

    # A remote endpoint switches to the rpyc-backed data layer.
    if args.remote_data:
        data_loader_layer = RpycAudioToTextDataLayer
    else:
        data_loader_layer = nemo_asr.AudioToTextDataLayer

    data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
        sample_rate=sample_rate,
        **jasper_params["AudioToMelSpectrogramPreprocessor"]
    )

    # Shared data-layer settings, overridden by the "eval" sub-section; the
    # "train"/"eval" sub-sections themselves are not data-layer arguments.
    dl_config = jasper_params["AudioToTextDataLayer"]
    eval_dl_params = copy.deepcopy(dl_config)
    eval_dl_params.update(dl_config.get("eval", {}))
    if args.remote_data:
        eval_dl_params["rpyc_host"] = args.remote_data
    eval_dl_params.pop("train", None)
    eval_dl_params.pop("eval", None)

    eval_manifests = args.eval_datasets or []
    if not eval_manifests:
        logging.warning("There were no val datasets passed")

    data_layers_eval = [
        data_loader_layer(
            manifest_filepath=manifest,
            sample_rate=sample_rate,
            labels=vocab,
            batch_size=args.eval_batch_size,
            num_workers=workers_per_loader,
            **eval_dl_params
        )
        for manifest in eval_manifests
    ]

    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **jasper_params["JasperEncoder"]
    )
    jasper_encoder.restore_from(args.encoder_checkpoint, local_rank=0)

    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )
    jasper_decoder.restore_from(args.decoder_checkpoint, local_rank=0)

    ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
    greedy_decoder = nemo_asr.GreedyCTCDecoder()

    # Assemble one eval DAG and one callback per manifest.
    callbacks = []
    for manifest, eval_dl in zip(eval_manifests, data_layers_eval):
        audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e = eval_dl()
        processed_signal_e, p_length_e = data_preprocessor(
            input_signal=audio_signal_e, length=a_sig_length_e
        )
        encoded_e, encoded_len_e = jasper_encoder(
            audio_signal=processed_signal_e, length=p_length_e
        )
        log_probs_e = jasper_decoder(encoder_output=encoded_e)
        predictions_e = greedy_decoder(log_probs=log_probs_e)
        loss_e = ctc_loss(
            log_probs=log_probs_e,
            targets=transcript_e,
            input_length=encoded_len_e,
            target_length=transcript_len_e,
        )

        # Tag results with the manifest file name up to its first dot.
        tagname = os.path.basename(manifest).split(".")[0]
        callbacks.append(
            nemo.core.EvaluatorCallback(
                eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
                user_iter_callback=partial(process_evaluation_batch, labels=vocab),
                user_epochs_done_callback=partial(
                    process_evaluation_epoch, tag=tagname
                ),
                eval_step=args.eval_freq,
                tb_writer=neural_factory.tb_writer,
            )
        )
    return callbacks


def main():
    """Parse arguments, build the evaluation graphs and run the evaluation."""
    args = parse_args()

    # Instantiate the Neural Factory with only placement and backend set; the
    # logging, checkpoint and AMP related arguments are not forwarded to it.
    neural_factory = nemo.core.NeuralModuleFactory(
        placement=nemo.core.DeviceType.GPU,
        backend=nemo.core.Backend.PyTorch,
    )
    args.num_gpus = neural_factory.world_size

    if args.local_rank is not None:
        logging.info("Doing ALL GPU")

    # build dags
    callbacks = create_all_dags(args, neural_factory)

    # evaluate model
    neural_factory.eval(callbacks=callbacks)


if __name__ == "__main__":
    main()