From 83db445a6fd359a377a0fd466dcea61f67537535 Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Thu, 14 May 2020 15:39:44 +0530 Subject: [PATCH] 1. added training utils with custom data loaders with remote rpyc dataservice support 2. fix validation correction dump path 3. cache dataset for precaching before training to memory 4. update dependencies --- jasper/data_utils/data_server.py | 38 ++- jasper/data_utils/validation/process.py | 2 +- jasper/training_utils/__init__.py | 1 + jasper/training_utils/data_loaders.py | 308 ++++++++++++++++++++++++ jasper/training_utils/featurizer.py | 51 ++++ jasper/{ => training_utils}/train.py | 30 ++- setup.py | 2 +- 7 files changed, 419 insertions(+), 13 deletions(-) create mode 100644 jasper/training_utils/__init__.py create mode 100644 jasper/training_utils/data_loaders.py create mode 100644 jasper/training_utils/featurizer.py rename jasper/{ => training_utils}/train.py (92%) diff --git a/jasper/data_utils/data_server.py b/jasper/data_utils/data_server.py index eecd848..856c381 100644 --- a/jasper/data_utils/data_server.py +++ b/jasper/data_utils/data_server.py @@ -1,24 +1,50 @@ import os -# from pathlib import Path +from pathlib import Path import typer import rpyc from rpyc.utils.server import ThreadedServer -import nemo.collections.asr as nemo_asr +import nemo +import pickle + +# import nemo.collections.asr as nemo_asr +from nemo.collections.asr.parts.segment import AudioSegment app = typer.Typer() +nemo.core.NeuralModuleFactory( + backend=nemo.core.Backend.PyTorch, placement=nemo.core.DeviceType.CPU +) + class ASRDataService(rpyc.Service): - def get_data_loader(self): - return nemo_asr.AudioToTextDataLayer + def exposed_get_path_samples( + self, file_path, target_sr, int_values, offset, duration, trim + ): + print(f"loading.. {file_path}") + audio = AudioSegment.from_file( + file_path, + target_sr=target_sr, + int_values=int_values, + offset=offset, + duration=duration, + trim=trim, + ) + # print(f"returning.. {len(audio.samples)} items of type{type(audio.samples)}") + return pickle.dumps(audio.samples) + + def exposed_read_path(self, file_path): + # print(f"reading path.. {file_path}") + return Path(file_path).read_bytes() @app.command() def run_server(port: int = 0): - listen_port = port if port else int(os.environ.get("ASR_RPYC_PORT", "8044")) + listen_port = port if port else int(os.environ.get("ASR_DARA_RPYC_PORT", "8064")) service = ASRDataService() - t = ThreadedServer(service, port=listen_port) + t = ThreadedServer( + service, port=listen_port, protocol_config={"allow_all_attrs": True} + ) typer.echo(f"starting asr server on {listen_port}...") t.start() diff --git a/jasper/data_utils/validation/process.py b/jasper/data_utils/validation/process.py index a2588aa..cdd1cbc 100644 --- a/jasper/data_utils/validation/process.py +++ b/jasper/data_utils/validation/process.py @@ -113,7 +113,7 @@ def dump_validation_ui_data( @app.command() -def dump_corrections(dump_path: Path = Path("./data/corrections.json")): +def dump_corrections(dump_path: Path = Path("./data/valiation_data/corrections.json")): col = get_mongo_conn().test.asr_validation cursor_obj = col.find({"type": "correction"}, projection={"_id": False}) diff --git a/jasper/training_utils/__init__.py b/jasper/training_utils/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/jasper/training_utils/__init__.py @@ -0,0 +1 @@ + diff --git a/jasper/training_utils/data_loaders.py b/jasper/training_utils/data_loaders.py new file mode 100644 index 0000000..5ffa3e4 --- /dev/null +++ b/jasper/training_utils/data_loaders.py @@ -0,0 +1,308 @@ +from functools import partial +import tempfile + +# from typing import Any, Dict, List, Optional + +import torch +import nemo + +# import nemo.collections.asr as nemo_asr +from nemo.backends.pytorch import DataLayerNM +from nemo.core import DeviceType + +# from nemo.core.neural_types import * +from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType, LabelsType +from nemo.utils.decorators import add_port_docs + +from nemo.collections.asr.parts.dataset import ( + # AudioDataset, + # AudioLabelDataset, + # KaldiFeatureDataset, + # TranscriptDataset, + parsers, + collections, + seq_collate_fn, +) + +# from functools import lru_cache +import rpyc +from .featurizer import RpycWaveformFeaturizer + +# from nemo.collections.asr.parts.features import WaveformFeaturizer + +# from nemo.collections.asr.parts.perturb import AudioAugmentor, perturbation_types + + +logging = nemo.logging + + +class CachedAudioDataset(torch.utils.data.Dataset): + """ + Dataset that loads tensors via a json file containing paths to audio + files, transcripts, and durations (in seconds). Each new line is a + different sample. Example below: + + {"audio_filepath": "/path/to/audio.wav", "text_filepath": + "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the + transcription", offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + + Args: + manifest_filepath: Path to manifest json as described above. Can + be comma-separated paths. + labels: String containing all the possible characters to map to + featurizer: Initialized featurizer class that converts paths of + audio to feature tensors + max_duration: If audio exceeds this length, do not include in dataset + min_duration: If audio is less than this length, do not include + in dataset + max_utts: Limit number of utterances + blank_index: blank character index, default = -1 + unk_index: unk_character index, default = -1 + normalize: whether to normalize transcript text (default): True + bos_id: Id of beginning of sequence symbol to append if not None + eos_id: Id of end of sequence symbol to append if not None + load_audio: Boolean flag indicate whether do or not load audio + """ + + def __init__( + self, + manifest_filepath, + labels, + featurizer, + max_duration=None, + min_duration=None, + max_utts=0, + blank_index=-1, + unk_index=-1, + normalize=True, + trim=False, + bos_id=None, + eos_id=None, + load_audio=True, + parser='en', + ): + self.collection = collections.ASRAudioText( + manifests_files=manifest_filepath.split(','), + parser=parsers.make_parser( + labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize, + ), + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + ) + self.index_feature_map = {} + + self.featurizer = featurizer + self.trim = trim + self.eos_id = eos_id + self.bos_id = bos_id + self.load_audio = load_audio + print(f'initializing dataset {manifest_filepath}') + for i in range(len(self.collection)): + self[i] + print(f'initializing complete') + + def __getitem__(self, index): + sample = self.collection[index] + if self.load_audio: + cached_features = self.index_feature_map.get(index) + if cached_features is not None: + features = cached_features + else: + features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,) + self.index_feature_map[index] = features + f, fl = features, torch.tensor(features.shape[0]).long() + else: + f, fl = None, None + + t, tl = sample.text_tokens, len(sample.text_tokens) + if self.bos_id is not None: + t = [self.bos_id] + t + tl += 1 + if self.eos_id is not None: + t = t + [self.eos_id] + tl += 1 + + return f, fl, torch.tensor(t).long(), torch.tensor(tl).long() + + def __len__(self): + return len(self.collection) + + +class RpycAudioToTextDataLayer(DataLayerNM): + """Data Layer for general ASR tasks. + + Module which reads ASR labeled data. It accepts comma-separated + JSON manifest files describing the correspondence between wav audio files + and their transcripts. JSON files should be of the following format:: + + {"audio_filepath": path_to_wav_0, "duration": time_in_sec_0, "text": \ +transcript_0} + ... + {"audio_filepath": path_to_wav_n, "duration": time_in_sec_n, "text": \ +transcript_n} + + Args: + manifest_filepath (str): Dataset parameter. + Path to JSON containing data. + labels (list): Dataset parameter. + List of characters that can be output by the ASR model. + For Jasper, this is the 28 character set {a-z '}. The CTC blank + symbol is automatically added later for models using ctc. + batch_size (int): batch size + sample_rate (int): Target sampling rate for data. Audio files will be + resampled to sample_rate if it is not already. + Defaults to 16000. + int_values (bool): Bool indicating whether the audio file is saved as + int data or float data. + Defaults to False. + eos_id (id): Dataset parameter. + End of string symbol id used for seq2seq models. + Defaults to None. + min_duration (float): Dataset parameter. + All training files which have a duration less than min_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to 0.1. + max_duration (float): Dataset parameter. + All training files which have a duration more than max_duration + are dropped. Note: Duration is read from the manifest JSON. + Defaults to None. + normalize_transcripts (bool): Dataset parameter. + Whether to use automatic text cleaning. + It is highly recommended to manually clean text for best results. + Defaults to True. + trim_silence (bool): Whether to use trim silence from beginning and end + of audio signal using librosa.effects.trim(). + Defaults to False. + load_audio (bool): Dataset parameter. + Controls whether the dataloader loads the audio signal and + transcript or just the transcript. + Defaults to True. + drop_last (bool): See PyTorch DataLoader. + Defaults to False. + shuffle (bool): See PyTorch DataLoader. + Defaults to True. + num_workers (int): See PyTorch DataLoader. + Defaults to 0. + perturb_config (dict): Currently disabled. + """ + + @property + @add_port_docs() + def output_ports(self): + """Returns definitions of module output ports. + """ + return { + # 'audio_signal': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}), + # 'a_sig_length': NeuralType({0: AxisType(BatchTag)}), + # 'transcripts': NeuralType({0: AxisType(BatchTag), 1: AxisType(TimeTag)}), + # 'transcript_length': NeuralType({0: AxisType(BatchTag)}), + "audio_signal": NeuralType( + ("B", "T"), + AudioSignal(freq=self._sample_rate) + if self is not None and self._sample_rate is not None + else AudioSignal(), + ), + "a_sig_length": NeuralType(tuple("B"), LengthsType()), + "transcripts": NeuralType(("B", "T"), LabelsType()), + "transcript_length": NeuralType(tuple("B"), LengthsType()), + } + + def __init__( + self, + manifest_filepath, + labels, + batch_size, + sample_rate=16000, + int_values=False, + bos_id=None, + eos_id=None, + pad_id=None, + min_duration=0.1, + max_duration=None, + normalize_transcripts=True, + trim_silence=False, + load_audio=True, + rpyc_host="", + drop_last=False, + shuffle=True, + num_workers=0, + ): + super().__init__() + self._sample_rate = sample_rate + + def rpyc_root_fn(): + return rpyc.connect( + rpyc_host, 8064, config={"sync_request_timeout": 600} + ).root + rpyc_conn = rpyc_root_fn() + + self._featurizer = RpycWaveformFeaturizer( + sample_rate=self._sample_rate, + int_values=int_values, + augmentor=None, + rpyc_conn=rpyc_conn, + ) + + def read_remote_manifests(): + local_mp = [] + for mrp in manifest_filepath.split(","): + md = rpyc_conn.read_path(mrp) + mf = tempfile.NamedTemporaryFile( + dir="/tmp", prefix="jasper_manifest.", delete=False + ) + mf.write(md) + mf.close() + local_mp.append(mf.name) + return ",".join(local_mp) + local_manifest_filepath = read_remote_manifests() + dataset_params = { + "manifest_filepath": local_manifest_filepath, + "labels": labels, + "featurizer": self._featurizer, + "max_duration": max_duration, + "min_duration": min_duration, + "normalize": normalize_transcripts, + "trim": trim_silence, + "bos_id": bos_id, + "eos_id": eos_id, + "load_audio": load_audio, + } + + self._dataset = CachedAudioDataset(**dataset_params) + self._batch_size = batch_size + + # Set up data loader + if self._placement == DeviceType.AllGpu: + logging.info("Parallelizing Datalayer.") + sampler = torch.utils.data.distributed.DistributedSampler(self._dataset) + else: + sampler = None + + if batch_size == -1: + batch_size = len(self._dataset) + + pad_id = 0 if pad_id is None else pad_id + self._dataloader = torch.utils.data.DataLoader( + dataset=self._dataset, + batch_size=batch_size, + collate_fn=partial(seq_collate_fn, token_pad_value=pad_id), + drop_last=drop_last, + shuffle=shuffle if sampler is None else False, + sampler=sampler, + num_workers=1, + ) + + def __len__(self): + return len(self._dataset) + + @property + def dataset(self): + return None + + @property + def data_iterator(self): + return self._dataloader diff --git a/jasper/training_utils/featurizer.py b/jasper/training_utils/featurizer.py new file mode 100644 index 0000000..030eb36 --- /dev/null +++ b/jasper/training_utils/featurizer.py @@ -0,0 +1,51 @@ +# import math + +# import librosa +import torch +import pickle +# import torch.nn as nn +# from torch_stft import STFT + +# from nemo import logging +from nemo.collections.asr.parts.perturb import AudioAugmentor +# from nemo.collections.asr.parts.segment import AudioSegment + + +class RpycWaveformFeaturizer(object): + def __init__( + self, sample_rate=16000, int_values=False, augmentor=None, rpyc_conn=None + ): + self.augmentor = augmentor if augmentor is not None else AudioAugmentor() + self.sample_rate = sample_rate + self.int_values = int_values + self.remote_path_samples = rpyc_conn.get_path_samples + + def max_augmentation_length(self, length): + return self.augmentor.max_augmentation_length(length) + + def process(self, file_path, offset=0, duration=0, trim=False): + audio = self.remote_path_samples( + file_path, + target_sr=self.sample_rate, + int_values=self.int_values, + offset=offset, + duration=duration, + trim=trim, + ) + return torch.tensor(pickle.loads(audio), dtype=torch.float) + + def process_segment(self, audio_segment): + self.augmentor.perturb(audio_segment) + return torch.tensor(audio_segment, dtype=torch.float) + + @classmethod + def from_config(cls, input_config, perturbation_configs=None): + if perturbation_configs is not None: + aa = AudioAugmentor.from_config(perturbation_configs) + else: + aa = None + + sample_rate = input_config.get("sample_rate", 16000) + int_values = input_config.get("int_values", False) + + return cls(sample_rate=sample_rate, int_values=int_values, augmentor=aa) diff --git a/jasper/train.py b/jasper/training_utils/train.py similarity index 92% rename from jasper/train.py rename to jasper/training_utils/train.py index def978f..4b4c97f 100644 --- a/jasper/train.py +++ b/jasper/training_utils/train.py @@ -15,7 +15,9 @@ from nemo.collections.asr.helpers import ( process_evaluation_batch, process_evaluation_epoch, ) + from nemo.utils.lr_policies import CosineAnnealing +from .data_loaders import RpycAudioToTextDataLayer logging = nemo.logging @@ -44,7 +46,7 @@ def parse_args(): eval_freq=100, load_dir="./train/models/jasper/", warmup_steps=3, - exp_name='jasper-speller' + exp_name="jasper-speller", ) # Overwrite default args @@ -69,6 +71,14 @@ def parse_args(): help="model configuration file: model.yaml", ) + parser.add_argument( + "--remote_data", + type=str, + required=False, + default="", + help="remote dataloader endpoint", + ) + # Create new args parser.add_argument("--exp_name", default="Jasper", type=str) parser.add_argument("--beta1", default=0.95, type=float) @@ -110,15 +120,23 @@ def create_all_dags(args, neural_factory): # Calculate num_workers for dataloader total_cpus = os.cpu_count() cpu_per_traindl = max(int(total_cpus / neural_factory.world_size), 1) - + # cpu_per_traindl = 1 # perturb_config = jasper_params.get('perturb', None) train_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) train_dl_params.update(jasper_params["AudioToTextDataLayer"]["train"]) del train_dl_params["train"] del train_dl_params["eval"] # del train_dl_params["normalize_transcripts"] - - data_layer = nemo_asr.AudioToTextDataLayer( + data_loader_layer = nemo_asr.AudioToTextDataLayer + if args.remote_data: + train_dl_params['rpyc_host'] = args.remote_data + data_loader_layer = RpycAudioToTextDataLayer + # if args.remote_data: + # # import pdb; pdb.set_trace() + # data_loader_layer = rpyc.connect( + # args.remote_data, 8064, config={"sync_request_timeout": 600} + # ).root.get_data_loader() + data_layer = data_loader_layer( manifest_filepath=args.train_dataset, sample_rate=sample_rate, labels=vocab, @@ -150,13 +168,15 @@ def create_all_dags(args, neural_factory): eval_dl_params = copy.deepcopy(jasper_params["AudioToTextDataLayer"]) eval_dl_params.update(jasper_params["AudioToTextDataLayer"]["eval"]) + if args.remote_data: + eval_dl_params['rpyc_host'] = args.remote_data del eval_dl_params["train"] del eval_dl_params["eval"] data_layers_eval = [] if args.eval_datasets: for eval_datasets in args.eval_datasets: - data_layer_eval = nemo_asr.AudioToTextDataLayer( + data_layer_eval = data_loader_layer( manifest_filepath=eval_datasets, sample_rate=sample_rate, labels=vocab, diff --git a/setup.py b/setup.py index 855d133..6b3e97b 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ setup( "console_scripts": [ "jasper_transcribe = jasper.transcribe:main", "jasper_asr_rpyc_server = jasper.server:main", - "jasper_asr_trainer = jasper.train:main", + "jasper_asr_trainer = jasper.training_utils.train:main", "jasper_asr_data_generate = jasper.data_utils.generator:main", "jasper_asr_data_recycle = jasper.data_utils.call_recycler:main", "jasper_asr_data_validation = jasper.data_utils.validation.process:main",