mirror of https://github.com/malarinv/tacotron2
distributed.py: replacing to avoid distributed error
parent
0ad65cc053
commit
52a30bb7b6
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.modules import Module
|
||||
from torch.autograd import Variable
|
||||
|
||||
def _flatten_dense_tensors(tensors):
|
||||
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
|
||||
|
|
@ -161,7 +162,7 @@ def apply_gradient_allreduce(module):
|
|||
|
||||
for param in list(module.parameters()):
|
||||
def allreduce_hook(*unused):
|
||||
param._execution_engine.queue_callback(allreduce_params)
|
||||
Variable._execution_engine.queue_callback(allreduce_params)
|
||||
if param.requires_grad:
|
||||
param.register_hook(allreduce_hook)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue