Parallelize data loading from remote

Malar Kannan 2020-05-29 12:14:14 +05:30
parent 9f9cb62b60
commit 3a5ce069ab
3 changed files with 36 additions and 13 deletions

View File

@ -37,9 +37,7 @@ def parse_args():
lr=0.002,
amp_opt_level="O1",
create_tb_writer=True,
model_config="./train/jasper-speller10x5dr.yaml",
# train_dataset="./train/asr_data/train_manifest.json",
# eval_datasets="./train/asr_data/test_manifest.json",
model_config="./train/jasper10x5dr.yaml",
work_dir="./train/work",
num_epochs=300,
weight_decay=0.005,
@ -61,7 +59,6 @@ def parse_args():
parser.add_argument(
"--num_epochs",
type=int,
default=None,
required=False,
help="number of epochs to train",
)

View File

@ -26,6 +26,8 @@ from nemo.collections.asr.parts.dataset import (
# from functools import lru_cache
import rpyc
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from .featurizer import RpycWaveformFeaturizer
# from nemo.collections.asr.parts.features import WaveformFeaturizer
@ -82,12 +84,16 @@ class CachedAudioDataset(torch.utils.data.Dataset):
bos_id=None,
eos_id=None,
load_audio=True,
parser='en',
parser="en",
):
self.collection = collections.ASRAudioText(
manifests_files=manifest_filepath.split(','),
manifests_files=manifest_filepath.split(","),
parser=parsers.make_parser(
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
labels=labels,
name=parser,
unk_id=unk_index,
blank_id=blank_index,
do_normalize=normalize,
),
min_duration=min_duration,
max_duration=max_duration,
@ -100,10 +106,23 @@ class CachedAudioDataset(torch.utils.data.Dataset):
self.eos_id = eos_id
self.bos_id = bos_id
self.load_audio = load_audio
print(f'initializing dataset {manifest_filepath}')
for i in range(len(self.collection)):
self[i]
print(f'initializing complete')
print(f"initializing dataset {manifest_filepath}")
def exec_func(i):
return self[i]
task_count = len(self.collection)
with ThreadPoolExecutor() as exe:
print("starting all loading tasks")
list(
tqdm(
exe.map(exec_func, range(task_count)),
position=0,
leave=True,
total=task_count,
)
)
print(f"initializing complete")
def __getitem__(self, index):
sample = self.collection[index]
@ -112,7 +131,12 @@ class CachedAudioDataset(torch.utils.data.Dataset):
if cached_features is not None:
features = cached_features
else:
features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,)
features = self.featurizer.process(
sample.audio_file,
offset=0,
duration=sample.duration,
trim=self.trim,
)
self.index_feature_map[index] = features
f, fl = features, torch.tensor(features.shape[0]).long()
else:
@ -238,6 +262,7 @@ transcript_n}
return rpyc.connect(
rpyc_host, 8064, config={"sync_request_timeout": 600}
).root
rpyc_conn = rpyc_root_fn()
self._featurizer = RpycWaveformFeaturizer(
@ -258,6 +283,7 @@ transcript_n}
mf.close()
local_mp.append(mf.name)
return ",".join(local_mp)
local_manifest_filepath = read_remote_manifests()
dataset_params = {
"manifest_filepath": local_manifest_filepath,

View File

@ -8,7 +8,7 @@ requirements = [
]
extra_requirements = {
"server": ["rpyc~=4.1.4"],
"server": ["rpyc~=4.1.4", "tqdm~=4.39.0"],
"data": [
"google-cloud-texttospeech~=1.0.1",
"tqdm~=4.39.0",