1
0
mirror of https://github.com/malarinv/tacotron2 synced 2026-03-08 01:32:35 +00:00

train.py: renaming function, removing dataparallel

This commit is contained in:
rafaelvalle
2018-11-27 18:04:12 -08:00
parent 3045ba125b
commit f06063f746

View File

@@ -84,9 +84,7 @@ def load_model(hparams):
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
if hparams.distributed_run:
model = DistributedDataParallel(model)
elif torch.cuda.device_count() > 1:
model = DataParallel(model)
model = apply_gradient_allreduce(model)
return model