mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 09:42:34 +00:00
Merge pull request #136 from GrzegorzKarchNV/master
Fixing concatenation error for fp16 distributed training
This commit is contained in:
@@ -140,7 +140,7 @@ def apply_gradient_allreduce(module):
     buckets = {}
     for param in module.parameters():
         if param.requires_grad and param.grad is not None:
-            tp = type(param.data)
+            tp = param.data.dtype
             if tp not in buckets:
                 buckets[tp] = []
             buckets[tp].append(param)
Reference in New Issue
Block a user