mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 01:32:35 +00:00
distributed.py: replacing to avoid distributed error
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.nn.modules import Module
|
from torch.nn.modules import Module
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
def _flatten_dense_tensors(tensors):
|
def _flatten_dense_tensors(tensors):
|
||||||
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
|
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
|
||||||
@@ -161,12 +162,12 @@ def apply_gradient_allreduce(module):
|
|||||||
|
|
||||||
for param in list(module.parameters()):
|
for param in list(module.parameters()):
|
||||||
def allreduce_hook(*unused):
|
def allreduce_hook(*unused):
|
||||||
param._execution_engine.queue_callback(allreduce_params)
|
Variable._execution_engine.queue_callback(allreduce_params)
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
param.register_hook(allreduce_hook)
|
param.register_hook(allreduce_hook)
|
||||||
|
|
||||||
def set_needs_reduction(self, input, output):
|
def set_needs_reduction(self, input, output):
|
||||||
self.needs_reduction = True
|
self.needs_reduction = True
|
||||||
|
|
||||||
module.register_forward_hook(set_needs_reduction)
|
module.register_forward_hook(set_needs_reduction)
|
||||||
return module
|
return module
|
||||||
|
|||||||
Reference in New Issue
Block a user