loss_scaler.py: patching loss scaler for compatibility with current pytorch

Branch: load_mel_from_disk
Rafael Valle 2018-05-15 09:50:08 -07:00
parent 2da7a2ebab
commit cd851585cb
1 changed file with 3 additions and 4 deletions

loss_scaler.py

@@ -51,11 +51,10 @@ class DynamicLossScaler:
     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
-        inf_count = torch.sum(x.abs() == float('inf'))
-        if inf_count > 0:
+        cpu_sum = float(x.float().sum())
+        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
             return True
-        nan_count = torch.sum(x != x)
-        return nan_count > 0
+        return False
 
     # `overflow` is boolean indicating whether we overflowed in gradient
     def update_scale(self, overflow):
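
For context, here is a minimal standalone sketch of the patched check (the free function and the sample tensors are illustrative, not part of the commit; only torch is assumed). The idea is that a single reduction suffices: an inf anywhere in the tensor propagates into the sum, and a NaN makes the sum NaN, which is the only float value for which v != v holds.

import torch

# Standalone version of the patched check, for illustration only.
def _has_inf_or_nan(x):
    # Collapse the tensor to one Python float; inf and NaN both
    # survive the reduction, so a single sum replaces the old
    # separate inf-count and NaN-count passes.
    cpu_sum = float(x.float().sum())
    if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
        return True
    return False

print(_has_inf_or_nan(torch.tensor([1.0, 2.0, 3.0])))      # False
print(_has_inf_or_nan(torch.tensor([1.0, float('inf')])))  # True
print(_has_inf_or_nan(torch.tensor([1.0, float('nan')])))  # True

One caveat: a sum of large finite gradients can itself overflow to inf, flagging an overflow even when every element is finite. For dynamic loss scaling this false positive is generally harmless, since the scaler simply skips that step and lowers the scale.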