mirror of https://github.com/malarinv/tacotron2
train.py: renaming function, removing dataparallel
parent
3045ba125b
commit
f06063f746
4
train.py
4
train.py
|
|
@ -84,9 +84,7 @@ def load_model(hparams):
|
|||
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
|
||||
|
||||
if hparams.distributed_run:
|
||||
model = DistributedDataParallel(model)
|
||||
elif torch.cuda.device_count() > 1:
|
||||
model = DataParallel(model)
|
||||
model = apply_gradient_allreduce(model)
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue