1
0
mirror of https://github.com/malarinv/jasper-asr.git synced 2026-03-08 10:32:35 +00:00

refactored module structure

This commit is contained in:
2020-05-21 16:47:45 +05:30
parent 2d5b720284
commit fca9c1aeb3
23 changed files with 17 additions and 115 deletions

View File

@@ -0,0 +1 @@

368
jasper/training/cli.py Normal file
View File

@@ -0,0 +1,368 @@
# Copyright (c) 2019 NVIDIA Corporation
import argparse
import copy
import math
import os
from pathlib import Path
from functools import partial
from ruamel.yaml import YAML
import nemo
import nemo.collections.asr as nemo_asr
import nemo.utils.argparse as nm_argparse
from nemo.collections.asr.helpers import (
monitor_asr_train_progress,
process_evaluation_batch,
process_evaluation_epoch,
)
from nemo.utils.lr_policies import CosineAnnealing
from .data_loaders import RpycAudioToTextDataLayer
logging = nemo.logging
def parse_args():
parser = argparse.ArgumentParser(
parents=[nm_argparse.NemoArgParser()],
description="Jasper",
conflict_handler="resolve",
)
parser.set_defaults(
checkpoint_dir=None,
optimizer="novograd",
batch_size=64,
eval_batch_size=64,
lr=0.002,
amp_opt_level="O1",
create_tb_writer=True,
model_config="./train/jasper-speller10x5dr.yaml",
# train_dataset="./train/asr_data/train_manifest.json",
# eval_datasets="./train/asr_data/test_manifest.json",
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
checkpoint_save_freq=200,
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
exp_name="jasper-speller",
)
# Overwrite default args
parser.add_argument(
"--max_steps",
type=int,
default=None,
required=False,
help="max number of steps to train",
)
parser.add_argument(
"--num_epochs",
type=int,
default=None,
required=False,
help="number of epochs to train",
)
parser.add_argument(
"--model_config",
type=str,
required=False,
help="model configuration file: model.yaml",
)
parser.add_argument(
"--remote_data",
type=str,
required=False,
default="",
help="remote dataloader endpoint",
)
parser.add_argument(
"--dataset",
type=str,
required=False,
default="",
help="dataset directory containing train/test manifests",
)
# Create new args
parser.add_argument("--exp_name", default="Jasper", type=str)
parser.add_argument("--beta1", default=0.95, type=float)
parser.add_argument("--beta2", default=0.25, type=float)
parser.add_argument("--warmup_steps", default=0, type=int)
parser.add_argument(
"--load_dir",
default=None,
type=str,
help="directory with pre-trained checkpoint",
)
args = parser.parse_args()
if args.max_steps is None and args.num_epochs is None:
raise ValueError("Either max_steps or num_epochs should be provided.")
return args
def construct_name(
name, lr, batch_size, max_steps, num_epochs, wd, optimizer, iter_per_step
):
if max_steps is not None:
return "{0}-lr_{1}-bs_{2}-s_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, max_steps, wd, optimizer, iter_per_step
)
else:
return "{0}-lr_{1}-bs_{2}-e_{3}-wd_{4}-opt_{5}-ips_{6}".format(
name, lr, batch_size, num_epochs, wd, optimizer, iter_per_step
)
def create_all_dags(args, neural_factory):
yaml = YAML(typ="safe")
with open(args.model_config) as f:
jasper_params = yaml.load(f)
vocab = jasper_params["labels"]
sample_rate = jasper_params["sample_rate"]
# Calculate num_workers for dataloader
total_cpus = os.cpu_count()
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
# perturb_config = jasper_params.get('perturb', None)
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
del train_dl_params["train"]
del train_dl_params["eval"]
# del train_dl_params["normalize_transcripts"]
if args.dataset:
d_path = Path(args.dataset)
if not args.train_dataset:
args.train_dataset = str(d_path / Path("train_manifest.json"))
if not args.eval_datasets:
args.eval_datasets = [str(d_path / Path("test_manifest.json"))]
data_loader_layer = nemo_asr.AudioToTextDataLayer
if args.remote_data:
train_dl_params["rpyc_host"] = args.remote_data
data_loader_layer = RpycAudioToTextDataLayer
data_layer = data_loader_layer(
manifest_filepath=args.train_dataset,
sample_rate=sample_rate,
labels=vocab,
batch_size=args.batch_size,
num_workers=cpu_per_traindl,
**train_dl_params,
# normalize_transcripts=False
)
N = len(data_layer)
steps_per_epoch = math.ceil(
N / (args.batch_size * args.iter_per_step * args.num_gpus)
)
logging.info("Have {0} examples to train on.".format(N))
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
)
multiply_batch_config = jasper_params.get("MultiplyBatch", None)
if multiply_batch_config:
multiply_batch = nemo_asr.MultiplyBatch(**multiply_batch_config)
spectr_augment_config = jasper_params.get("SpectrogramAugmentation", None)
if spectr_augment_config:
data_spectr_augmentation = nemo_asr.SpectrogramAugmentation(
**spectr_augment_config
)
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
if args.remote_data:
eval_dl_params["rpyc_host"] = args.remote_data
del eval_dl_params["train"]
del eval_dl_params["eval"]
data_layers_eval = []
if args.eval_datasets:
for eval_datasets in args.eval_datasets:
data_layer_eval = data_loader_layer(
manifest_filepath=eval_datasets,
sample_rate=sample_rate,
labels=vocab,
batch_size=args.eval_batch_size,
num_workers=cpu_per_traindl,
**eval_dl_params,
)
data_layers_eval.append(data_layer_eval)
else:
logging.warning("There were no val datasets passed")
jasper_encoder = nemo_asr.JasperEncoder(
feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"],
**jasper_params["JasperEncoder"],
)
jasper_decoder = nemo_asr.JasperDecoderForCTC(
feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"],
num_classes=len(vocab),
)
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
greedy_decoder = nemo_asr.GreedyCTCDecoder()
logging.info("================================")
logging.info(f"Number of parameters in encoder: {jasper_encoder.num_weights}")
logging.info(f"Number of parameters in decoder: {jasper_decoder.num_weights}")
logging.info(
f"Total number of parameters in model: "
f"{jasper_decoder.num_weights + jasper_encoder.num_weights}"
)
logging.info("================================")
# Train DAG
(audio_signal_t, a_sig_length_t, transcript_t, transcript_len_t) = data_layer()
processed_signal_t, p_length_t = data_preprocessor(
input_signal=audio_signal_t, length=a_sig_length_t
)
if multiply_batch_config:
(
processed_signal_t,
p_length_t,
transcript_t,
transcript_len_t,
) = multiply_batch(
in_x=processed_signal_t,
in_x_len=p_length_t,
in_y=transcript_t,
in_y_len=transcript_len_t,
)
if spectr_augment_config:
processed_signal_t = data_spectr_augmentation(input_spec=processed_signal_t)
encoded_t, encoded_len_t = jasper_encoder(
audio_signal=processed_signal_t, length=p_length_t
)
log_probs_t = jasper_decoder(encoder_output=encoded_t)
predictions_t = greedy_decoder(log_probs=log_probs_t)
loss_t = ctc_loss(
log_probs=log_probs_t,
targets=transcript_t,
input_length=encoded_len_t,
target_length=transcript_len_t,
)
# Callbacks needed to print info to console and Tensorboard
train_callback = nemo.core.SimpleLossLoggerCallback(
tensors=[loss_t, predictions_t, transcript_t, transcript_len_t],
print_func=partial(monitor_asr_train_progress, labels=vocab),
get_tb_values=lambda x: [("loss", x[0])],
tb_writer=neural_factory.tb_writer,
)
chpt_callback = nemo.core.CheckpointCallback(
folder=neural_factory.checkpoint_dir,
load_from_folder=args.load_dir,
step_freq=args.checkpoint_save_freq,
)
callbacks = [train_callback, chpt_callback]
# assemble eval DAGs
for i, eval_dl in enumerate(data_layers_eval):
(audio_signal_e, a_sig_length_e, transcript_e, transcript_len_e) = eval_dl()
processed_signal_e, p_length_e = data_preprocessor(
input_signal=audio_signal_e, length=a_sig_length_e
)
encoded_e, encoded_len_e = jasper_encoder(
audio_signal=processed_signal_e, length=p_length_e
)
log_probs_e = jasper_decoder(encoder_output=encoded_e)
predictions_e = greedy_decoder(log_probs=log_probs_e)
loss_e = ctc_loss(
log_probs=log_probs_e,
targets=transcript_e,
input_length=encoded_len_e,
target_length=transcript_len_e,
)
# create corresponding eval callback
tagname = os.path.basename(args.eval_datasets[i]).split(".")[0]
eval_callback = nemo.core.EvaluatorCallback(
eval_tensors=[loss_e, predictions_e, transcript_e, transcript_len_e],
user_iter_callback=partial(process_evaluation_batch, labels=vocab),
user_epochs_done_callback=partial(process_evaluation_epoch, tag=tagname),
eval_step=args.eval_freq,
tb_writer=neural_factory.tb_writer,
)
callbacks.append(eval_callback)
return loss_t, callbacks, steps_per_epoch
def main():
args = parse_args()
name = construct_name(
args.exp_name,
args.lr,
args.batch_size,
args.max_steps,
args.num_epochs,
args.weight_decay,
args.optimizer,
args.iter_per_step,
)
log_dir = name
if args.work_dir:
log_dir = os.path.join(args.work_dir, name)
# instantiate Neural Factory with supported backend
neural_factory = nemo.core.NeuralModuleFactory(
backend=nemo.core.Backend.PyTorch,
local_rank=args.local_rank,
optimization_level=args.amp_opt_level,
log_dir=log_dir,
checkpoint_dir=args.checkpoint_dir,
create_tb_writer=args.create_tb_writer,
files_to_copy=[args.model_config, __file__],
cudnn_benchmark=args.cudnn_benchmark,
tensorboard_dir=args.tensorboard_dir,
)
args.num_gpus = neural_factory.world_size
checkpoint_dir = neural_factory.checkpoint_dir
if args.local_rank is not None:
logging.info("Doing ALL GPU")
# build dags
train_loss, callbacks, steps_per_epoch = create_all_dags(args, neural_factory)
# train model
neural_factory.train(
tensors_to_optimize=[train_loss],
callbacks=callbacks,
lr_policy=CosineAnnealing(
args.max_steps
if args.max_steps is not None
else args.num_epochs * steps_per_epoch,
warmup_steps=args.warmup_steps,
),
optimizer=args.optimizer,
optimization_params={
"num_epochs": args.num_epochs,
"max_steps": args.max_steps,
"lr": args.lr,
"betas": (args.beta1, args.beta2),
"weight_decay": args.weight_decay,
"grad_norm_clip": None,
},
batches_per_step=args.iter_per_step,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,308 @@
from functools import partial
import tempfile
# from typing import Any, Dict, List, Optional
import torch
import nemo
# import nemo.collections.asr as nemo_asr
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
# from nemo.core.neural_types import *
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
from nemo.utils.decorators import add_port_docs
from nemo.collections.asr.parts.dataset import (
# AudioDataset,
# AudioLabelDataset,
# KaldiFeatureDataset,
# TranscriptDataset,
parsers,
collections,
seq_collate_fn,
)
# from functools import lru_cache
import rpyc
from .featurizer import RpycWaveformFeaturizer
# from nemo.collections.asr.parts.features import WaveformFeaturizer
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
logging = nemo.logging
class CachedAudioDataset(torch.utils.data.Dataset):
"""
Dataset that loads tensors via a json file containing paths to audio
files, transcripts, and durations (in seconds). Each new line is a
different sample. Example below:
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
"/path/to/audio.txt", "duration": 23.147}
...
{"audio_filepath": "/path/to/audio.wav", "text": "the
transcription", offset": 301.75, "duration": 0.82, "utt":
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
Args:
manifest_filepath: Path to manifest json as described above. Can
be comma-separated paths.
labels: String containing all the possible characters to map to
featurizer: Initialized featurizer class that converts paths of
audio to feature tensors
max_duration: If audio exceeds this length, do not include in dataset
min_duration: If audio is less than this length, do not include
in dataset
max_utts: Limit number of utterances
blank_index: blank character index, default = -1
unk_index: unk_character index, default = -1
normalize: whether to normalize transcript text (default): True
bos_id: Id of beginning of sequence symbol to append if not None
eos_id: Id of end of sequence symbol to append if not None
load_audio: Boolean flag indicate whether do or not load audio
"""
def __init__(
self,
manifest_filepath,
labels,
featurizer,
max_duration=None,
min_duration=None,
max_utts=0,
blank_index=-1,
unk_index=-1,
normalize=True,
trim=False,
bos_id=None,
eos_id=None,
load_audio=True,
parser='en',
):
self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath.split(','),
parser=parsers.make_parser(
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
),
min_duration=min_duration,
max_duration=max_duration,
max_number=max_utts,
)
self.index_feature_map = {}
self.featurizer = featurizer
self.trim = trim
self.eos_id = eos_id
self.bos_id = bos_id
self.load_audio = load_audio
print(f'initializing dataset {manifest_filepath}')
for i in range(len(self.collection)):
self[i]
print(f'initializing complete')
def __getitem__(self, index):
sample = self.collection[index]
if self.load_audio:
cached_features = self.index_feature_map.get(index)
if cached_features is not None:
features = cached_features
else:
features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,)
self.index_feature_map[index] = features
f, fl = features, torch.tensor(features.shape[0]).long()
else:
f, fl = None, None
t, tl = sample.text_tokens, len(sample.text_tokens)
if self.bos_id is not None:
t = [self.bos_id] + t
tl += 1
if self.eos_id is not None:
t = t + [self.eos_id]
tl += 1
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
def __len__(self):
return len(self.collection)
class RpycAudioToTextDataLayer(DataLayerNM):
"""Data Layer for general ASR tasks.
Module which reads ASR labeled data. It accepts comma-separated
JSON manifest files describing the correspondence between wav audio files
and their transcripts. JSON files should be of the following format::
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
transcript_0}
...
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
transcript_n}
Args:
manifest_filepath (str): Dataset parameter.
Path to JSON containing data.
labels (list): Dataset parameter.
List of characters that can be output by the ASR model.
For Jasper, this is the 28 character set {a-z '}. The CTC blank
symbol is automatically added later for models using ctc.
batch_size (int): batch size
sample_rate (int): Target sampling rate for data. Audio files will be
resampled to sample_rate if it is not already.
Defaults to 16000.
int_values (bool): Bool indicating whether the audio file is saved as
int data or float data.
Defaults to False.
eos_id (id): Dataset parameter.
End of string symbol id used for seq2seq models.
Defaults to None.
min_duration (float): Dataset parameter.
All training files which have a duration less than min_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to 0.1.
max_duration (float): Dataset parameter.
All training files which have a duration more than max_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to None.
normalize_transcripts (bool): Dataset parameter.
Whether to use automatic text cleaning.
It is highly recommended to manually clean text for best results.
Defaults to True.
trim_silence (bool): Whether to use trim silence from beginning and end
of audio signal using librosa.effects.trim().
Defaults to False.
load_audio (bool): Dataset parameter.
Controls whether the dataloader loads the audio signal and
transcript or just the transcript.
Defaults to True.
drop_last (bool): See PyTorch DataLoader.
Defaults to False.
shuffle (bool): See PyTorch DataLoader.
Defaults to True.
num_workers (int): See PyTorch DataLoader.
Defaults to 0.
perturb_config (dict): Currently disabled.
"""
@property
@add_port_docs()
def output_ports(self):
"""Returns definitions of module output ports.
"""
return {
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
"audio_signal": NeuralType(
("B", "T"),
AudioSignal(freq=self._sample_rate)
if self is not None and self._sample_rate is not None
else AudioSignal(),
),
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
"transcripts": NeuralType(("B", "T"), LabelsType()),
"transcript_length": NeuralType(tuple("B"), LengthsType()),
}
def __init__(
self,
manifest_filepath,
labels,
batch_size,
sample_rate=16000,
int_values=False,
bos_id=None,
eos_id=None,
pad_id=None,
min_duration=0.1,
max_duration=None,
normalize_transcripts=True,
trim_silence=False,
load_audio=True,
rpyc_host="",
drop_last=False,
shuffle=True,
num_workers=0,
):
super().__init__()
self._sample_rate = sample_rate
def rpyc_root_fn():
return rpyc.connect(
rpyc_host, 8064, config={"sync_request_timeout": 600}
).root
rpyc_conn = rpyc_root_fn()
self._featurizer = RpycWaveformFeaturizer(
sample_rate=self._sample_rate,
int_values=int_values,
augmentor=None,
rpyc_conn=rpyc_conn,
)
def read_remote_manifests():
local_mp = []
for mrp in manifest_filepath.split(","):
md = rpyc_conn.read_path(mrp)
mf = tempfile.NamedTemporaryFile(
dir="/tmp", prefix="jasper_manifest.", delete=False
)
mf.write(md)
mf.close()
local_mp.append(mf.name)
return ",".join(local_mp)
local_manifest_filepath = read_remote_manifests()
dataset_params = {
"manifest_filepath": local_manifest_filepath,
"labels": labels,
"featurizer": self._featurizer,
"max_duration": max_duration,
"min_duration": min_duration,
"normalize": normalize_transcripts,
"trim": trim_silence,
"bos_id": bos_id,
"eos_id": eos_id,
"load_audio": load_audio,
}
self._dataset = CachedAudioDataset(**dataset_params)
self._batch_size = batch_size
# Set up data loader
if self._placement == DeviceType.AllGpu:
logging.info("Parallelizing Datalayer.")
sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
else:
sampler = None
if batch_size == -1:
batch_size = len(self._dataset)
pad_id = 0 if pad_id is None else pad_id
self._dataloader = torch.utils.data.DataLoader(
dataset=self._dataset,
batch_size=batch_size,
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
drop_last=drop_last,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=1,
)
def __len__(self):
return len(self._dataset)
@property
def dataset(self):
return None
@property
def data_iterator(self):
return self._dataloader

View File

@@ -0,0 +1,51 @@
# import math
# import librosa
import torch
import pickle
# import torch.nn as nn
# from torch_stft import STFT
# from nemo import logging
from nemo.collections.asr.parts.perturb import AudioAugmentor
# from nemo.collections.asr.parts.segment import AudioSegment
class RpycWaveformFeaturizer(object):
def __init__(
self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
):
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
self.sample_rate = sample_rate
self.int_values = int_values
self.remote_path_samples = rpyc_conn.get_path_samples
def max_augmentation_length(self, length):
return self.augmentor.max_augmentation_length(length)
def process(self, file_path, offset=0, duration=0, trim=False):
audio = self.remote_path_samples(
file_path,
target_sr=self.sample_rate,
int_values=self.int_values,
offset=offset,
duration=duration,
trim=trim,
)
return torch.tensor(pickle.loads(audio), dtype=torch.float)
def process_segment(self, audio_segment):
self.augmentor.perturb(audio_segment)
return torch.tensor(audio_segment, dtype=torch.float)
@classmethod
def from_config(cls, input_config, perturbation_configs=None):
if perturbation_configs is not None:
aa = AudioAugmentor.from_config(perturbation_configs)
else:
aa = None
sample_rate = input_config.get("sample_rate", 16000)
int_values = input_config.get("int_values", False)
return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa)