Fixing concatenation error for fp16 ditributed training

experiments
gkarch 2019-02-01 09:55:59 +01:00
parent 825ffa47d1
commit df4a466af2
1 changed files with 1 additions and 1 deletions

View File

@ -140,7 +140,7 @@ def apply_gradient_allreduce(module):
buckets = {}
for param in module.parameters():
if param.requires_grad and param.grad is not None:
tp = type(param.data)
tp = param.data.dtype
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)