mirror of https://github.com/malarinv/tacotron2
distributed.py: rewrite
parent 1683a57ae5
commit d0aa9e7d32
@@ -118,3 +118,55 @@ class DistributedDataParallel(Module):
         super(DistributedDataParallel, self).train(mode)
         self.module.train(mode)
 '''
+'''
+Modifies existing model to do gradient allreduce, but doesn't change class
+so you don't need "module"
+'''
+def apply_gradient_allreduce(module):
+    if not hasattr(dist, '_backend'):
+        module.warn_on_half = True
+    else:
+        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+    for p in module.state_dict().values():
+        if not torch.is_tensor(p):
+            continue
+        dist.broadcast(p, 0)
+
+    def allreduce_params():
+        if(module.needs_reduction):
+            module.needs_reduction = False
+            buckets = {}
+            for param in module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = type(param.data)
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+            if module.warn_on_half:
+                if torch.cuda.HalfTensor in buckets:
+                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
+                          " It is recommended to use the NCCL backend in this case. This currently requires" +
+                          "PyTorch built from top of tree master.")
+                    module.warn_on_half = False
+
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                dist.all_reduce(coalesced)
+                coalesced /= dist.get_world_size()
+                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                    buf.copy_(synced)
+
+    for param in list(module.parameters()):
+        def allreduce_hook(*unused):
+            param._execution_engine.queue_callback(allreduce_params)
+        if param.requires_grad:
+            param.register_hook(allreduce_hook)
+
+    def set_needs_reduction(self, input, output):
+        self.needs_reduction = True
+
+    module.register_forward_hook(set_needs_reduction)
+    return module
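For context, a minimal usage sketch of the apply_gradient_allreduce helper added above. It assumes torch.distributed has already been set up by a launcher; the rendezvous address, rank, world size, and the MyModel class below are placeholders, not values taken from this repository.

import torch
import torch.distributed as dist

from distributed import apply_gradient_allreduce  # the module patched in this commit

# Placeholder process-group setup; real values come from the training launcher.
dist.init_process_group(backend='nccl',
                        init_method='tcp://localhost:54321',
                        rank=0, world_size=1)

model = MyModel().cuda()                 # hypothetical model
model = apply_gradient_allreduce(model)  # same class as before, no .module wrapper needed

# During training, the forward hook sets needs_reduction, and the per-parameter
# backward hooks queue allreduce_params(), so gradients are flattened, summed
# across ranks with all_reduce, and divided by the world size before
# optimizer.step() runs.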