parallelize data loading from remote
parent
9f9cb62b60
commit
3a5ce069ab
|
|
@ -37,9 +37,7 @@ def parse_args():
|
|||
lr=0.002,
|
||||
amp_opt_level="O1",
|
||||
create_tb_writer=True,
|
||||
model_config="./train/jasper-speller10x5dr.yaml",
|
||||
# train_dataset="./train/asr_data/train_manifest.json",
|
||||
# eval_datasets="./train/asr_data/test_manifest.json",
|
||||
model_config="./train/jasper10x5dr.yaml",
|
||||
work_dir="./train/work",
|
||||
num_epochs=300,
|
||||
weight_decay=0.005,
|
||||
|
|
@ -61,7 +59,6 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
required=False,
|
||||
help="number of epochs to train",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ from nemo.collections.asr.parts.dataset import (
|
|||
|
||||
# from functools import lru_cache
|
||||
import rpyc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
from .featurizer import RpycWaveformFeaturizer
|
||||
|
||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||
|
|
@ -82,12 +84,16 @@ class CachedAudioDataset(torch.utils.data.Dataset):
|
|||
bos_id=None,
|
||||
eos_id=None,
|
||||
load_audio=True,
|
||||
parser='en',
|
||||
parser="en",
|
||||
):
|
||||
self.collection = collections.ASRAudioText(
|
||||
manifests_files=manifest_filepath.split(','),
|
||||
manifests_files=manifest_filepath.split(","),
|
||||
parser=parsers.make_parser(
|
||||
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
|
||||
labels=labels,
|
||||
name=parser,
|
||||
unk_id=unk_index,
|
||||
blank_id=blank_index,
|
||||
do_normalize=normalize,
|
||||
),
|
||||
min_duration=min_duration,
|
||||
max_duration=max_duration,
|
||||
|
|
@ -100,10 +106,23 @@ class CachedAudioDataset(torch.utils.data.Dataset):
|
|||
self.eos_id = eos_id
|
||||
self.bos_id = bos_id
|
||||
self.load_audio = load_audio
|
||||
print(f'initializing dataset {manifest_filepath}')
|
||||
for i in range(len(self.collection)):
|
||||
self[i]
|
||||
print(f'initializing complete')
|
||||
print(f"initializing dataset {manifest_filepath}")
|
||||
|
||||
def exec_func(i):
|
||||
return self[i]
|
||||
|
||||
task_count = len(self.collection)
|
||||
with ThreadPoolExecutor() as exe:
|
||||
print("starting all loading tasks")
|
||||
list(
|
||||
tqdm(
|
||||
exe.map(exec_func, range(task_count)),
|
||||
position=0,
|
||||
leave=True,
|
||||
total=task_count,
|
||||
)
|
||||
)
|
||||
print(f"initializing complete")
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.collection[index]
|
||||
|
|
@ -112,7 +131,12 @@ class CachedAudioDataset(torch.utils.data.Dataset):
|
|||
if cached_features is not None:
|
||||
features = cached_features
|
||||
else:
|
||||
features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,)
|
||||
features = self.featurizer.process(
|
||||
sample.audio_file,
|
||||
offset=0,
|
||||
duration=sample.duration,
|
||||
trim=self.trim,
|
||||
)
|
||||
self.index_feature_map[index] = features
|
||||
f, fl = features, torch.tensor(features.shape[0]).long()
|
||||
else:
|
||||
|
|
@ -238,6 +262,7 @@ transcript_n}
|
|||
return rpyc.connect(
|
||||
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
||||
).root
|
||||
|
||||
rpyc_conn = rpyc_root_fn()
|
||||
|
||||
self._featurizer = RpycWaveformFeaturizer(
|
||||
|
|
@ -258,6 +283,7 @@ transcript_n}
|
|||
mf.close()
|
||||
local_mp.append(mf.name)
|
||||
return ",".join(local_mp)
|
||||
|
||||
local_manifest_filepath = read_remote_manifests()
|
||||
dataset_params = {
|
||||
"manifest_filepath": local_manifest_filepath,
|
||||
|
|
|
|||
Loading…
Reference in New Issue