mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 01:32:35 +00:00
loss_scaler.py: patching loss scaler for compatibility with current pytorch
This commit is contained in:
@@ -51,11 +51,10 @@ class DynamicLossScaler:
     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
-        inf_count = torch.sum(x.abs() == float('inf'))
-        if inf_count > 0:
-            return True
-        nan_count = torch.sum(x != x)
-        return nan_count > 0
+        cpu_sum = float(x.float().sum())
+        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
+            return True
+        return False

     # `overflow` is boolean indicating whether we overflowed in gradient
     def update_scale(self, overflow):
Reference in New Issue
Block a user