1
0
mirror of https://github.com/malarinv/tacotron2 synced 2026-03-08 01:32:35 +00:00

distributed.py: replacing to avoid distributed error

This commit is contained in:
rafaelvalle
2018-11-27 21:01:26 -08:00
parent 0ad65cc053
commit 52a30bb7b6

View File

@@ -1,6 +1,7 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.nn.modules import Module from torch.nn.modules import Module
from torch.autograd import Variable
def _flatten_dense_tensors(tensors): def _flatten_dense_tensors(tensors):
"""Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
@@ -161,12 +162,12 @@ def apply_gradient_allreduce(module):
for param in list(module.parameters()): for param in list(module.parameters()):
def allreduce_hook(*unused): def allreduce_hook(*unused):
param._execution_engine.queue_callback(allreduce_params) Variable._execution_engine.queue_callback(allreduce_params)
if param.requires_grad: if param.requires_grad:
param.register_hook(allreduce_hook) param.register_hook(allreduce_hook)
def set_needs_reduction(self, input, output): def set_needs_reduction(self, input, output):
self.needs_reduction = True self.needs_reduction = True
module.register_forward_hook(set_needs_reduction) module.register_forward_hook(set_needs_reduction)
return module return module