mirror of https://github.com/malarinv/tacotron2
commit da30fd8709
@@ -51,11 +51,10 @@ class DynamicLossScaler:
 
     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
-        inf_count = torch.sum(x.abs() == float('inf'))
-        if inf_count > 0:
-            return True
-        nan_count = torch.sum(x != x)
-        return nan_count > 0
+        cpu_sum = float(x.float().sum())
+        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            return True
+        return False
 
     # `overflow` is boolean indicating whether we overflowed in gradient
     def update_scale(self, overflow):
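The rewritten check relies on two IEEE-754 facts: any inf or NaN in a tensor propagates through a sum (inf + -inf yields NaN, which is still caught), and NaN is the only float value that compares unequal to itself. Folding the test into a single reduction also means one device-to-host transfer instead of two separate boolean reductions. A minimal standalone sketch of the new logic (the enclosing DynamicLossScaler class is not shown in this hunk):

import torch

def _has_inf_or_nan(x):
    # One sum and one device->host copy; any inf/NaN in `x` poisons it.
    cpu_sum = float(x.float().sum())
    # NaN compares unequal to itself, so the last clause catches NaN sums.
    if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
        return True
    return False

print(_has_inf_or_nan(torch.tensor([1.0, 2.0])))           # False
print(_has_inf_or_nan(torch.tensor([1.0, float('inf')])))  # True
print(_has_inf_or_nan(torch.tensor([0.0, float('nan')])))  # True

The update_scale hook in the trailing context is where dynamic loss scaling reacts to that overflow flag; the usual policy (constants and attribute names below are illustrative assumptions, not taken from this commit) backs the scale off after an overflow and grows it again after a window of clean steps:

def update_scale(self, overflow):
    # Illustrative sketch only: shrink on overflow, grow after
    # `scale_window` consecutive clean iterations.
    if overflow:
        self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
        self.last_overflow_iter = self.cur_iter
    elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
        self.cur_scale *= self.scale_factor
    self.cur_iter += 1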
train.py (7 changed lines)
@@ -2,6 +2,7 @@ import os
 import time
 import argparse
 import math
+from numpy import finfo
 
 import torch
 from distributed import DistributedDataParallel
@@ -77,7 +78,9 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
 
 def load_model(hparams):
     model = Tacotron2(hparams).cuda()
-    model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
+    if hparams.fp16_run:
+        model = batchnorm_to_float(model.half())
+        model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
 
     if hparams.distributed_run:
         model = DistributedDataParallel(model)
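Two details in this hunk deserve a note. First, batchnorm_to_float is defined elsewhere in train.py and not shown here; presumably it restores BatchNorm submodules to fp32 after model.half(), since batch norm statistics are a well-known source of instability in half precision. A sketch under that assumption:

import torch.nn as nn

def batchnorm_to_float(module):
    # Assumed behavior: recursively convert BatchNorm layers back to
    # fp32 while the rest of the half-precision model stays in fp16.
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        batchnorm_to_float(child)
    return module

Second, the new score_mask_value line is the reason for the `from numpy import finfo` import added above: the very negative value used to silence padded attention scores before the softmax must stay finite in fp16, and anything below float16's range overflows to -inf once the model is halved. finfo('float16').min is the most negative finite half-precision value:

from numpy import finfo
import torch

print(float(finfo('float16').min))                  # -65504.0
print(torch.tensor(finfo('float32').min).half())    # tensor(-inf, dtype=torch.float16)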
@@ -276,7 +279,7 @@ if __name__ == '__main__':
     torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
 
     print("FP16 Run:", hparams.fp16_run)
-    print("Dynamic Loss Scaling", hparams.dynamic_loss_scaling)
+    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
     print("Distributed Run:", hparams.distributed_run)
     print("cuDNN Enabled:", hparams.cudnn_enabled)
     print("cuDNN Benchmark:", hparams.cudnn_benchmark)