From 3a5ce069ab9cee4661ec7f8bd27ae578385ad8bf Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Fri, 29 May 2020 12:14:14 +0530 Subject: [PATCH] parallelize data loading from remote --- jasper/training/cli.py | 5 +--- jasper/training/data_loaders.py | 42 ++++++++++++++++++++++++++------- setup.py | 2 +- 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/jasper/training/cli.py b/jasper/training/cli.py index 15f13fc..1f628bd 100644 --- a/jasper/training/cli.py +++ b/jasper/training/cli.py @@ -37,9 +37,7 @@ def parse_args(): lr=0.002, amp_opt_level="O1", create_tb_writer=True, - model_config="./train/jasper-speller10x5dr.yaml", - # train_dataset="./train/asr_data/train_manifest.json", - # eval_datasets="./train/asr_data/test_manifest.json", + model_config="./train/jasper10x5dr.yaml", work_dir="./train/work", num_epochs=300, weight_decay=0.005, @@ -61,7 +59,6 @@ def parse_args(): parser.add_argument( "--num_epochs", type=int, - default=None, required=False, help="number of epochs to train", ) diff --git a/jasper/training/data_loaders.py b/jasper/training/data_loaders.py index 5ffa3e4..d181dfa 100644 --- a/jasper/training/data_loaders.py +++ b/jasper/training/data_loaders.py @@ -26,6 +26,8 @@ from nemo.collections.asr.parts.dataset import ( # from functools import lru_cache import rpyc +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm from .featurizer import RpycWaveformFeaturizer # from nemo.collections.asr.parts.features import WaveformFeaturizer @@ -82,12 +84,16 @@ class CachedAudioDataset(torch.utils.data.Dataset): bos_id=None, eos_id=None, load_audio=True, - parser='en', + parser="en", ): self.collection = collections.ASRAudioText( - manifests_files=manifest_filepath.split(','), + manifests_files=manifest_filepath.split(","), parser=parsers.make_parser( - labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize, + labels=labels, + name=parser, + unk_id=unk_index, + blank_id=blank_index, + do_normalize=normalize, ), min_duration=min_duration, max_duration=max_duration, @@ -100,10 +106,23 @@ class CachedAudioDataset(torch.utils.data.Dataset): self.eos_id = eos_id self.bos_id = bos_id self.load_audio = load_audio - print(f'initializing dataset {manifest_filepath}') - for i in range(len(self.collection)): - self[i] - print(f'initializing complete') + print(f"initializing dataset {manifest_filepath}") + + def exec_func(i): + return self[i] + + task_count = len(self.collection) + with ThreadPoolExecutor() as exe: + print("starting all loading tasks") + list( + tqdm( + exe.map(exec_func, range(task_count)), + position=0, + leave=True, + total=task_count, + ) + ) + print(f"initializing complete") def __getitem__(self, index): sample = self.collection[index] @@ -112,7 +131,12 @@ class CachedAudioDataset(torch.utils.data.Dataset): if cached_features is not None: features = cached_features else: - features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,) + features = self.featurizer.process( + sample.audio_file, + offset=0, + duration=sample.duration, + trim=self.trim, + ) self.index_feature_map[index] = features f, fl = features, torch.tensor(features.shape[0]).long() else: @@ -238,6 +262,7 @@ transcript_n} return rpyc.connect( rpyc_host, 8064, config={"sync_request_timeout": 600} ).root + rpyc_conn = rpyc_root_fn() self._featurizer = RpycWaveformFeaturizer( @@ -258,6 +283,7 @@ transcript_n} mf.close() local_mp.append(mf.name) return ",".join(local_mp) + local_manifest_filepath = read_remote_manifests() dataset_params = { "manifest_filepath": local_manifest_filepath, diff --git a/setup.py b/setup.py index 38f253e..87cf7e9 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ requirements = [ ] extra_requirements = { - "server": ["rpyc~=4.1.4"], + "server": ["rpyc~=4.1.4", "tqdm~=4.39.0"], "data": [ "google-cloud-texttospeech~=1.0.1", "tqdm~=4.39.0",