parallelize data loading from remote

Malar Kannan 2020-05-29 12:14:14 +05:30
parent 9f9cb62b60
commit 3a5ce069ab
3 changed files with 36 additions and 13 deletions

View File

@@ -37,9 +37,7 @@ def parse_args():
lr=0.002, lr=0.002,
amp_opt_level="O1", amp_opt_level="O1",
create_tb_writer=True, create_tb_writer=True,
model_config="./train/jasper-speller10x5dr.yaml", model_config="./train/jasper10x5dr.yaml",
# train_dataset="./train/asr_data/train_manifest.json",
# eval_datasets="./train/asr_data/test_manifest.json",
work_dir="./train/work", work_dir="./train/work",
num_epochs=300, num_epochs=300,
weight_decay=0.005, weight_decay=0.005,
@@ -61,7 +59,6 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--num_epochs", "--num_epochs",
type=int, type=int,
default=None,
required=False, required=False,
help="number of epochs to train", help="number of epochs to train",
) )

View File

@@ -26,6 +26,8 @@ from nemo.collections.asr.parts.dataset import (
# from functools import lru_cache # from functools import lru_cache
import rpyc import rpyc
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from .featurizer import RpycWaveformFeaturizer from .featurizer import RpycWaveformFeaturizer
# from nemo.collections.asr.parts.features import WaveformFeaturizer # from nemo.collections.asr.parts.features import WaveformFeaturizer
@@ -82,12 +84,16 @@ class CachedAudioDataset(torch.utils.data.Dataset):
bos_id=None, bos_id=None,
eos_id=None, eos_id=None,
load_audio=True, load_audio=True,
parser='en', parser="en",
): ):
self.collection = collections.ASRAudioText( self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath.split(','), manifests_files=manifest_filepath.split(","),
parser=parsers.make_parser( parser=parsers.make_parser(
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize, labels=labels,
name=parser,
unk_id=unk_index,
blank_id=blank_index,
do_normalize=normalize,
), ),
min_duration=min_duration, min_duration=min_duration,
max_duration=max_duration, max_duration=max_duration,
@@ -100,10 +106,23 @@ class CachedAudioDataset(torch.utils.data.Dataset):
self.eos_id = eos_id self.eos_id = eos_id
self.bos_id = bos_id self.bos_id = bos_id
self.load_audio = load_audio self.load_audio = load_audio
print(f'initializing dataset {manifest_filepath}') print(f"initializing dataset {manifest_filepath}")
for i in range(len(self.collection)):
self[i] def exec_func(i):
print(f'initializing complete') return self[i]
task_count = len(self.collection)
with ThreadPoolExecutor() as exe:
print("starting all loading tasks")
list(
tqdm(
exe.map(exec_func, range(task_count)),
position=0,
leave=True,
total=task_count,
)
)
print(f"initializing complete")
def __getitem__(self, index): def __getitem__(self, index):
sample = self.collection[index] sample = self.collection[index]
@@ -112,7 +131,12 @@ class CachedAudioDataset(torch.utils.data.Dataset):
if cached_features is not None: if cached_features is not None:
features = cached_features features = cached_features
else: else:
features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,) features = self.featurizer.process(
sample.audio_file,
offset=0,
duration=sample.duration,
trim=self.trim,
)
self.index_feature_map[index] = features self.index_feature_map[index] = features
f, fl = features, torch.tensor(features.shape[0]).long() f, fl = features, torch.tensor(features.shape[0]).long()
else: else:
@@ -238,6 +262,7 @@ transcript_n}
return rpyc.connect( return rpyc.connect(
rpyc_host, 8064, config={"sync_request_timeout": 600} rpyc_host, 8064, config={"sync_request_timeout": 600}
).root ).root
rpyc_conn = rpyc_root_fn() rpyc_conn = rpyc_root_fn()
self._featurizer = RpycWaveformFeaturizer( self._featurizer = RpycWaveformFeaturizer(
@@ -258,6 +283,7 @@ transcript_n}
mf.close() mf.close()
local_mp.append(mf.name) local_mp.append(mf.name)
return ",".join(local_mp) return ",".join(local_mp)
local_manifest_filepath = read_remote_manifests() local_manifest_filepath = read_remote_manifests()
dataset_params = { dataset_params = {
"manifest_filepath": local_manifest_filepath, "manifest_filepath": local_manifest_filepath,

View File

@@ -8,7 +8,7 @@ requirements = [
] ]
extra_requirements = { extra_requirements = {
"server": ["rpyc~=4.1.4"], "server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
"data": [ "data": [
"google-cloud-texttospeech~=1.0.1", "google-cloud-texttospeech~=1.0.1",
"tqdm~=4.39.0", "tqdm~=4.39.0",