mirror of https://github.com/malarinv/tacotron2
Merge branch 'master' of https://github.com/NVIDIA/tacotron2
commit
eb2a171690
10
train.py
10
train.py
|
|
@ -45,10 +45,14 @@ def prepare_dataloaders(hparams):
|
||||||
valset = TextMelLoader(hparams.validation_files, hparams)
|
valset = TextMelLoader(hparams.validation_files, hparams)
|
||||||
collate_fn = TextMelCollate(hparams.n_frames_per_step)
|
collate_fn = TextMelCollate(hparams.n_frames_per_step)
|
||||||
|
|
||||||
train_sampler = DistributedSampler(trainset) \
|
if hparams.distributed_run:
|
||||||
if hparams.distributed_run else None
|
train_sampler = DistributedSampler(trainset)
|
||||||
|
shuffle = False
|
||||||
|
else:
|
||||||
|
train_sampler = None
|
||||||
|
shuffle = True
|
||||||
|
|
||||||
train_loader = DataLoader(trainset, num_workers=1, shuffle=True,
|
train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
batch_size=hparams.batch_size, pin_memory=False,
|
batch_size=hparams.batch_size, pin_memory=False,
|
||||||
drop_last=True, collate_fn=collate_fn)
|
drop_last=True, collate_fn=collate_fn)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue