1. Add training utils with custom data loaders that support a remote rpyc data service

2. Fix the validation-correction dump path
3. Cache the dataset in memory before training (precaching)
4. Update dependencies
Malar Kannan 2020-05-14 15:39:44 +05:30
parent d4aef4088d
commit 83db445a6f
7 changed files with 419 additions and 13 deletions

View File

@@ -1,24 +1,50 @@
import os
-# from pathlib import Path
from pathlib import Path
import typer
import rpyc
from rpyc.utils.server import ThreadedServer
import nemo.collections.asr as nemo_asr
import nemo
import pickle
# import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.segment import AudioSegment
app = typer.Typer()
nemo.core.NeuralModuleFactory(
backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
)
class ASRDataService(rpyc.Service):
def get_data_loader(self):
return nemo_asr.AudioToTextDataLayer
def exposed_get_path_samples(
self, file_path, target_sr, int_values, offset, duration, trim
):
print(f"loading.. {file_path}")
audio = AudioSegment.from_file(
file_path,
target_sr=target_sr,
int_values=int_values,
offset=offset,
duration=duration,
trim=trim,
)
# print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}")
return pickle.dumps(audio.samples)
def exposed_read_path(self, file_path):
# print(f"reading path.. {file_path}")
return Path(file_path).read_bytes()
@app.command()
def run_server(port: int = 0):
-listen_port = port if port else int(os.environ.get("ASR_RPYC_PORT", "8044"))
listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064"))
service = ASRDataService()
-t = ThreadedServer(service, port=listen_port)
t = ThreadedServer(
service, port=listen_port, protocol_config={"allow_all_attrs": True}
)
typer.echo(f"starting asr server on {listen_port}...")
t.start()
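
For reference, a minimal client-side sketch of how a consumer talks to this service (the repo's featurizer, shown further down, does the same). The host name and file paths are placeholders; per rpyc convention, the `exposed_` prefix is dropped on the client:

```python
import pickle

import rpyc

# Hypothetical host; 8064 is the server's default port above.
conn = rpyc.connect("dataservice.local", 8064, config={"sync_request_timeout": 600})

# exposed_get_path_samples is invoked as root.get_path_samples.
raw = conn.root.get_path_samples(
    "/data/audio/sample.wav",  # placeholder path on the server machine
    target_sr=16000,
    int_values=False,
    offset=0,
    duration=0,
    trim=False,
)
samples = pickle.loads(raw)  # samples array pickled by the server

# exposed_read_path returns raw bytes, e.g. for manifest files.
manifest_bytes = conn.root.read_path("/data/manifests/train.json")  # placeholder
```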

View File

@@ -113,7 +113,7 @@ def dump_validation_ui_data(
@app.command()
def dump_corrections(dump_path: Path = Path("./data/corrections.json")):
def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")):
col = get_mongo_conn().test.asr_validation
cursor_obj = col.find({"type": "correction"}, projection={"_id": False})

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,308 @@
from functools import partial
import tempfile
# from typing import Any, Dict, List, Optional
import torch
import nemo
# import nemo.collections.asr as nemo_asr
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
# from nemo.core.neural_types import *
from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType
from nemo.utils.decorators import add_port_docs
from nemo.collections.asr.parts.dataset import (
# AudioDataset,
# AudioLabelDataset,
# KaldiFeatureDataset,
# TranscriptDataset,
parsers,
collections,
seq_collate_fn,
)
# from functools import lru_cache
import rpyc
from .featurizer import RpycWaveformFeaturizer
# from nemo.collections.asr.parts.features import WaveformFeaturizer
# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types
logging = nemo.logging
class CachedAudioDataset(torch.utils.data.Dataset):
"""
Dataset that loads tensors via a json file containing paths to audio
files, transcripts, and durations (in seconds). Each new line is a
different sample. Example below:
{"audio_filepath": "/path/to/audio.wav", "text_filepath":
"/path/to/audio.txt", "duration": 23.147}
...
{"audio_filepath": "/path/to/audio.wav", "text": "the
transcription", offset": 301.75, "duration": 0.82, "utt":
"utterance_id", "ctm_utt": "en_4156", "side": "A"}
Args:
manifest_filepath: Path to manifest json as described above. Can
be comma-separated paths.
labels: String containing all the possible characters to map to
featurizer: Initialized featurizer class that converts paths of
audio to feature tensors
max_duration: If audio exceeds this length, do not include in dataset
min_duration: If audio is less than this length, do not include
in dataset
max_utts: Limit number of utterances
blank_index: blank character index, default = -1
unk_index: unk_character index, default = -1
normalize: whether to normalize transcript text. Defaults to True.
bos_id: Id of beginning of sequence symbol to append if not None
eos_id: Id of end of sequence symbol to append if not None
load_audio: Boolean flag indicating whether or not to load audio
"""
def __init__(
self,
manifest_filepath,
labels,
featurizer,
max_duration=None,
min_duration=None,
max_utts=0,
blank_index=-1,
unk_index=-1,
normalize=True,
trim=False,
bos_id=None,
eos_id=None,
load_audio=True,
parser='en',
):
self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath.split(','),
parser=parsers.make_parser(
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
),
min_duration=min_duration,
max_duration=max_duration,
max_number=max_utts,
)
self.index_feature_map = {}
self.featurizer = featurizer
self.trim = trim
self.eos_id = eos_id
self.bos_id = bos_id
self.load_audio = load_audio
print(f"initializing dataset {manifest_filepath}")
# Precache: touch every sample once so its features are computed (over rpyc
# when using RpycWaveformFeaturizer) and stored in index_feature_map.
for i in range(len(self.collection)):
self[i]
print("initialization complete")
def __getitem__(self, index):
sample = self.collection[index]
if self.load_audio:
cached_features = self.index_feature_map.get(index)
if cached_features is not None:
features = cached_features
else:
features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,)
self.index_feature_map[index] = features
f, fl = features, torch.tensor(features.shape[0]).long()
else:
f, fl = None, None
t, tl = sample.text_tokens, len(sample.text_tokens)
if self.bos_id is not None:
t = [self.bos_id] + t
tl += 1
if self.eos_id is not None:
t = t + [self.eos_id]
tl += 1
return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()
def __len__(self):
return len(self.collection)
class RpycAudioToTextDataLayer(DataLayerNM):
"""Data Layer for general ASR tasks.
Module which reads ASR labeled data. It accepts comma-separated
JSON manifest files describing the correspondence between wav audio files
and their transcripts. JSON files should be of the following format::
{"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \
transcript_0}
...
{"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \
transcript_n}
Args:
manifest_filepath (str): Dataset parameter.
Path to JSON containing data.
labels (list): Dataset parameter.
List of characters that can be output by the ASR model.
For Jasper, this is the 28 character set {a-z '}. The CTC blank
symbol is automatically added later for models using ctc.
batch_size (int): batch size
sample_rate (int): Target sampling rate for data. Audio files will be
resampled to sample_rate if it is not already.
Defaults to 16000.
int_values (bool): Bool indicating whether the audio file is saved as
int data or float data.
Defaults to False.
eos_id (id): Dataset parameter.
End of string symbol id used for seq2seq models.
Defaults to None.
min_duration (float): Dataset parameter.
All training files which have a duration less than min_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to 0.1.
max_duration (float): Dataset parameter.
All training files which have a duration more than max_duration
are dropped. Note: Duration is read from the manifest JSON.
Defaults to None.
normalize_transcripts (bool): Dataset parameter.
Whether to use automatic text cleaning.
It is highly recommended to manually clean text for best results.
Defaults to True.
trim_silence (bool): Whether to trim silence from the beginning and end
of the audio signal using librosa.effects.trim().
Defaults to False.
load_audio (bool): Dataset parameter.
Controls whether the dataloader loads the audio signal and
transcript or just the transcript.
Defaults to True.
drop_last (bool): See PyTorch DataLoader.
Defaults to False.
shuffle (bool): See PyTorch DataLoader.
Defaults to True.
num_workers (int): See PyTorch DataLoader.
Defaults to 0.
perturb_config (dict): Currently disabled.
"""
@property
@add_port_docs()
def output_ports(self):
"""Returns definitions of module output ports.
"""
return {
# 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
# 'a_sig_length': NeuralType({0: AxisType(BatchTag)}),
# 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}),
# 'transcript_length': NeuralType({0: AxisType(BatchTag)}),
"audio_signal": NeuralType(
("B", "T"),
AudioSignal(freq=self._sample_rate)
if self is not None and self._sample_rate is not None
else AudioSignal(),
),
"a_sig_length": NeuralType(tuple("B"), LengthsType()),
"transcripts": NeuralType(("B", "T"), LabelsType()),
"transcript_length": NeuralType(tuple("B"), LengthsType()),
}
def __init__(
self,
manifest_filepath,
labels,
batch_size,
sample_rate=16000,
int_values=False,
bos_id=None,
eos_id=None,
pad_id=None,
min_duration=0.1,
max_duration=None,
normalize_transcripts=True,
trim_silence=False,
load_audio=True,
rpyc_host="",
drop_last=False,
shuffle=True,
num_workers=0,
):
super().__init__()
self._sample_rate = sample_rate
def rpyc_root_fn():
# 8064 matches the data service's default listen port (see run_server above).
return rpyc.connect(
rpyc_host, 8064, config={"sync_request_timeout": 600}
).root
rpyc_conn = rpyc_root_fn()
self._featurizer = RpycWaveformFeaturizer(
sample_rate=self._sample_rate,
int_values=int_values,
augmentor=None,
rpyc_conn=rpyc_conn,
)
def read_remote_manifests():
# Copy each remote manifest into a local temp file so the dataset can parse it.
local_mp = []
for mrp in manifest_filepath.split(","):
md = rpyc_conn.read_path(mrp)
mf = tempfile.NamedTemporaryFile(
dir="/tmp", prefix="jasper_manifest.", delete=False
)
mf.write(md)
mf.close()
local_mp.append(mf.name)
return ",".join(local_mp)
local_manifest_filepath = read_remote_manifests()
dataset_params = {
"manifest_filepath": local_manifest_filepath,
"labels": labels,
"featurizer": self._featurizer,
"max_duration": max_duration,
"min_duration": min_duration,
"normalize": normalize_transcripts,
"trim": trim_silence,
"bos_id": bos_id,
"eos_id": eos_id,
"load_audio": load_audio,
}
self._dataset = CachedAudioDataset(**dataset_params)
self._batch_size = batch_size
# Set up data loader
if self._placement == DeviceType.AllGpu:
logging.info("Parallelizing Datalayer.")
sampler = torch.utils.data.distributed.DistributedSampler(self._dataset)
else:
sampler = None
if batch_size == -1:
batch_size = len(self._dataset)
pad_id = 0 if pad_id is None else pad_id
self._dataloader = torch.utils.data.DataLoader(
dataset=self._dataset,
batch_size=batch_size,
collate_fn=partial(seq_collate_fn, token_pad_value=pad_id),
drop_last=drop_last,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=1,  # pinned to 1 (num_workers arg not forwarded; dataset is precached in memory)
)
def __len__(self):
return len(self._dataset)
@property
def dataset(self):
return None
@property
def data_iterator(self):
return self._dataloader
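
A usage sketch for the layer above, assuming the data service is already running; the host, manifest path, and vocabulary are placeholders, and the module path follows the jasper.training_utils package referenced in setup.py below. Note that construction is eager: CachedAudioDataset fetches and caches features for every sample before returning.

```python
import nemo
from jasper.training_utils.data_loaders import RpycAudioToTextDataLayer

# A neural module factory must exist before any DataLayerNM is built.
nemo.core.NeuralModuleFactory(
    backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU
)

train_layer = RpycAudioToTextDataLayer(
    manifest_filepath="/data/manifests/train.json",  # path on the remote host
    labels=list("abcdefghijklmnopqrstuvwxyz '"),     # placeholder vocabulary
    batch_size=32,
    sample_rate=16000,
    rpyc_host="dataservice.local",                   # hypothetical service host
)

# Batches follow seq_collate_fn: (audio, audio_len, tokens, token_len).
audio, audio_len, tokens, token_len = next(iter(train_layer.data_iterator))
```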

View File

@@ -0,0 +1,51 @@
# import math
# import librosa
import torch
import pickle
# import torch.nn as nn
# from torch_stft import STFT
# from nemo import logging
from nemo.collections.asr.parts.perturb import AudioAugmentor
# from nemo.collections.asr.parts.segment import AudioSegment
class RpycWaveformFeaturizer(object):
def __init__(
self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None
):
self.augmentor = augmentor if augmentor is not None else AudioAugmentor()
self.sample_rate = sample_rate
self.int_values = int_values
self.remote_path_samples = rpyc_conn.get_path_samples
def max_augmentation_length(self, length):
return self.augmentor.max_augmentation_length(length)
def process(self, file_path, offset=0, duration=0, trim=False):
audio = self.remote_path_samples(
file_path,
target_sr=self.sample_rate,
int_values=self.int_values,
offset=offset,
duration=duration,
trim=trim,
)
return torch.tensor(pickle.loads(audio), dtype=torch.float)
def process_segment(self, audio_segment):
self.augmentor.perturb(audio_segment)
# AudioSegment itself is not tensor-convertible; use its samples array.
return torch.tensor(audio_segment.samples, dtype=torch.float)
@classmethod
def from_config(cls, input_config, perturbation_configs=None, rpyc_conn=None):
if perturbation_configs is not None:
aa = AudioAugmentor.from_config(perturbation_configs)
else:
aa = None
sample_rate = input_config.get("sample_rate", 16000)
int_values = input_config.get("int_values", False)
# Forward the rpyc connection; without it the featurizer cannot fetch samples.
return cls(
sample_rate=sample_rate, int_values=int_values, augmentor=aa, rpyc_conn=rpyc_conn
)
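
A matching sketch for using the featurizer on its own; the connection's root object is what gets passed in, mirroring how RpycAudioToTextDataLayer wires it up (module path assumed, host and path are placeholders):

```python
import rpyc

from jasper.training_utils.featurizer import RpycWaveformFeaturizer

conn = rpyc.connect("dataservice.local", 8064, config={"sync_request_timeout": 600})
featurizer = RpycWaveformFeaturizer(
    sample_rate=16000,
    int_values=False,
    augmentor=None,
    rpyc_conn=conn.root,  # featurizer calls rpyc_conn.get_path_samples
)

# Fetches samples over rpyc and returns a 1-D float tensor.
features = featurizer.process("/data/audio/sample.wav", duration=2.5)
```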

View File

@@ -15,7 +15,9 @@ from nemo.collections.asr.helpers import (
process_evaluation_batch,
process_evaluation_epoch,
)
from nemo.utils.lr_policies import CosineAnnealing
from .data_loaders import RpycAudioToTextDataLayer
logging = nemo.logging
@@ -44,7 +46,7 @@ def parse_args():
eval_freq=100,
load_dir="./train/models/jasper/",
warmup_steps=3,
-exp_name='jasper-speller'
exp_name="jasper-speller",
)
# Overwrite default args
@@ -69,6 +71,14 @@ def parse_args():
help="model configuration file: model.yaml",
)
parser.add_argument(
"--remote_data",
type=str,
required=False,
default="",
help="remote dataloader endpoint",
)
# Create new args
parser.add_argument("--exp_name", default="Jasper", type=str)
parser.add_argument("--beta1", default=0.95, type=float)
@@ -110,15 +120,23 @@ def create_all_dags(args, neural_factory):
# Calculate num_workers for dataloader
total_cpus = os.cpu_count()
cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1)
# cpu_per_traindl = 1
# perturb_config = jasper_params.get('perturb', None)
train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"])
del train_dl_params["train"]
del train_dl_params["eval"]
# del train_dl_params["normalize_transcripts"]
-data_layer = nemo_asr.AudioToTextDataLayer(
data_loader_layer = nemo_asr.AudioToTextDataLayer
if args.remote_data:
train_dl_params['rpyc_host'] = args.remote_data
data_loader_layer = RpycAudioToTextDataLayer
# if args.remote_data:
# # import pdb; pdb.set_trace()
# data_loader_layer = rpyc.connect(
# args.remote_data, 8064, config={"sync_request_timeout": 600}
# ).root.get_data_loader()
data_layer = data_loader_layer(
manifest_filepath=args.train_dataset,
sample_rate=sample_rate,
labels=vocab,
@@ -150,13 +168,15 @@ def create_all_dags(args, neural_factory):
eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"])
eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"])
if args.remote_data:
eval_dl_params['rpyc_host'] = args.remote_data
del eval_dl_params["train"]
del eval_dl_params["eval"]
data_layers_eval = []
if args.eval_datasets:
for eval_datasets in args.eval_datasets:
-data_layer_eval = nemo_asr.AudioToTextDataLayer(
data_layer_eval = data_loader_layer(
manifest_filepath=eval_datasets,
sample_rate=sample_rate,
labels=vocab,
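
Net effect: start run_server on the machine that holds the audio and manifests, then launch the trainer with --remote_data=<host>; both the train and eval data layers are then built as RpycAudioToTextDataLayer against that host, while with the flag unset the stock nemo_asr.AudioToTextDataLayer path is unchanged.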

View File

@@ -61,7 +61,7 @@
"console_scripts": [
"jasper_transcribe = jasper.transcribe:main",
"jasper_asr_rpyc_server = jasper.server:main",
"jasper_asr_trainer = jasper.train:main",
"jasper_asr_trainer = jasper.training_utils.train:main",
"jasper_asr_data_generate = jasper.data_utils.generator:main",
"jasper_asr_data_recycle = jasper.data_utils.call_recycler:main",
"jasper_asr_data_validation = jasper.data_utils.validation.process:main",