mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 09:42:34 +00:00
train.py: renaming function, removing dataparallel
This commit is contained in:
4
train.py
4
train.py
@@ -84,9 +84,7 @@ def load_model(hparams):
|
|||||||
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
|
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
|
||||||
|
|
||||||
if hparams.distributed_run:
|
if hparams.distributed_run:
|
||||||
model = DistributedDataParallel(model)
|
model = apply_gradient_allreduce(model)
|
||||||
elif torch.cuda.device_count() > 1:
|
|
||||||
model = DataParallel(model)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user