parallelize data loading from remote
parent
9f9cb62b60
commit
3a5ce069ab
|
|
@ -37,9 +37,7 @@ def parse_args():
|
||||||
lr=0.002,
|
lr=0.002,
|
||||||
amp_opt_level="O1",
|
amp_opt_level="O1",
|
||||||
create_tb_writer=True,
|
create_tb_writer=True,
|
||||||
model_config="./train/jasper-speller10x5dr.yaml",
|
model_config="./train/jasper10x5dr.yaml",
|
||||||
# train_dataset="./train/asr_data/train_manifest.json",
|
|
||||||
# eval_datasets="./train/asr_data/test_manifest.json",
|
|
||||||
work_dir="./train/work",
|
work_dir="./train/work",
|
||||||
num_epochs=300,
|
num_epochs=300,
|
||||||
weight_decay=0.005,
|
weight_decay=0.005,
|
||||||
|
|
@ -61,7 +59,6 @@ def parse_args():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_epochs",
|
"--num_epochs",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
|
||||||
required=False,
|
required=False,
|
||||||
help="number of epochs to train",
|
help="number of epochs to train",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ from nemo.collections.asr.parts.dataset import (
|
||||||
|
|
||||||
# from functools import lru_cache
|
# from functools import lru_cache
|
||||||
import rpyc
|
import rpyc
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
from .featurizer import RpycWaveformFeaturizer
|
from .featurizer import RpycWaveformFeaturizer
|
||||||
|
|
||||||
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
# from nemo.collections.asr.parts.features import WaveformFeaturizer
|
||||||
|
|
@ -82,12 +84,16 @@ class CachedAudioDataset(torch.utils.data.Dataset):
|
||||||
bos_id=None,
|
bos_id=None,
|
||||||
eos_id=None,
|
eos_id=None,
|
||||||
load_audio=True,
|
load_audio=True,
|
||||||
parser='en',
|
parser="en",
|
||||||
):
|
):
|
||||||
self.collection = collections.ASRAudioText(
|
self.collection = collections.ASRAudioText(
|
||||||
manifests_files=manifest_filepath.split(','),
|
manifests_files=manifest_filepath.split(","),
|
||||||
parser=parsers.make_parser(
|
parser=parsers.make_parser(
|
||||||
labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
|
labels=labels,
|
||||||
|
name=parser,
|
||||||
|
unk_id=unk_index,
|
||||||
|
blank_id=blank_index,
|
||||||
|
do_normalize=normalize,
|
||||||
),
|
),
|
||||||
min_duration=min_duration,
|
min_duration=min_duration,
|
||||||
max_duration=max_duration,
|
max_duration=max_duration,
|
||||||
|
|
@ -100,10 +106,23 @@ class CachedAudioDataset(torch.utils.data.Dataset):
|
||||||
self.eos_id = eos_id
|
self.eos_id = eos_id
|
||||||
self.bos_id = bos_id
|
self.bos_id = bos_id
|
||||||
self.load_audio = load_audio
|
self.load_audio = load_audio
|
||||||
print(f'initializing dataset {manifest_filepath}')
|
print(f"initializing dataset {manifest_filepath}")
|
||||||
for i in range(len(self.collection)):
|
|
||||||
self[i]
|
def exec_func(i):
|
||||||
print(f'initializing complete')
|
return self[i]
|
||||||
|
|
||||||
|
task_count = len(self.collection)
|
||||||
|
with ThreadPoolExecutor() as exe:
|
||||||
|
print("starting all loading tasks")
|
||||||
|
list(
|
||||||
|
tqdm(
|
||||||
|
exe.map(exec_func, range(task_count)),
|
||||||
|
position=0,
|
||||||
|
leave=True,
|
||||||
|
total=task_count,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"initializing complete")
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = self.collection[index]
|
sample = self.collection[index]
|
||||||
|
|
@ -112,7 +131,12 @@ class CachedAudioDataset(torch.utils.data.Dataset):
|
||||||
if cached_features is not None:
|
if cached_features is not None:
|
||||||
features = cached_features
|
features = cached_features
|
||||||
else:
|
else:
|
||||||
features = self.featurizer.process(sample.audio_file, offset=0, duration=sample.duration, trim=self.trim,)
|
features = self.featurizer.process(
|
||||||
|
sample.audio_file,
|
||||||
|
offset=0,
|
||||||
|
duration=sample.duration,
|
||||||
|
trim=self.trim,
|
||||||
|
)
|
||||||
self.index_feature_map[index] = features
|
self.index_feature_map[index] = features
|
||||||
f, fl = features, torch.tensor(features.shape[0]).long()
|
f, fl = features, torch.tensor(features.shape[0]).long()
|
||||||
else:
|
else:
|
||||||
|
|
@ -238,6 +262,7 @@ transcript_n}
|
||||||
return rpyc.connect(
|
return rpyc.connect(
|
||||||
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
rpyc_host, 8064, config={"sync_request_timeout": 600}
|
||||||
).root
|
).root
|
||||||
|
|
||||||
rpyc_conn = rpyc_root_fn()
|
rpyc_conn = rpyc_root_fn()
|
||||||
|
|
||||||
self._featurizer = RpycWaveformFeaturizer(
|
self._featurizer = RpycWaveformFeaturizer(
|
||||||
|
|
@ -258,6 +283,7 @@ transcript_n}
|
||||||
mf.close()
|
mf.close()
|
||||||
local_mp.append(mf.name)
|
local_mp.append(mf.name)
|
||||||
return ",".join(local_mp)
|
return ",".join(local_mp)
|
||||||
|
|
||||||
local_manifest_filepath = read_remote_manifests()
|
local_manifest_filepath = read_remote_manifests()
|
||||||
dataset_params = {
|
dataset_params = {
|
||||||
"manifest_filepath": local_manifest_filepath,
|
"manifest_filepath": local_manifest_filepath,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue