mirror of https://github.com/malarinv/tacotron2
train.py: using amp for mixed precision training
parent
bb20035586
commit
0274619e45
38
train.py
38
train.py
|
|
@ -10,8 +10,6 @@ import torch.distributed as dist
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from fp16_optimizer import FP16_Optimizer
|
|
||||||
|
|
||||||
from model import Tacotron2
|
from model import Tacotron2
|
||||||
from data_utils import TextMelLoader, TextMelCollate
|
from data_utils import TextMelLoader, TextMelCollate
|
||||||
from loss_function import Tacotron2Loss
|
from loss_function import Tacotron2Loss
|
||||||
|
|
@ -19,15 +17,6 @@ from logger import Tacotron2Logger
|
||||||
from hparams import create_hparams
|
from hparams import create_hparams
|
||||||
|
|
||||||
|
|
||||||
def batchnorm_to_float(module):
|
|
||||||
"""Converts batch norm modules to FP32"""
|
|
||||||
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
|
|
||||||
module.float()
|
|
||||||
for child in module.children():
|
|
||||||
batchnorm_to_float(child)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_tensor(tensor, n_gpus):
|
def reduce_tensor(tensor, n_gpus):
|
||||||
rt = tensor.clone()
|
rt = tensor.clone()
|
||||||
dist.all_reduce(rt, op=dist.reduce_op.SUM)
|
dist.all_reduce(rt, op=dist.reduce_op.SUM)
|
||||||
|
|
@ -80,8 +69,7 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
|
||||||
def load_model(hparams):
|
def load_model(hparams):
|
||||||
model = Tacotron2(hparams).cuda()
|
model = Tacotron2(hparams).cuda()
|
||||||
if hparams.fp16_run:
|
if hparams.fp16_run:
|
||||||
model = batchnorm_to_float(model.half())
|
model.decoder.attention_layer.score_mask_value = finfo('float16').min
|
||||||
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
|
|
||||||
|
|
||||||
if hparams.distributed_run:
|
if hparams.distributed_run:
|
||||||
model = apply_gradient_allreduce(model)
|
model = apply_gradient_allreduce(model)
|
||||||
|
|
@ -177,9 +165,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
|
||||||
learning_rate = hparams.learning_rate
|
learning_rate = hparams.learning_rate
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
|
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
|
||||||
weight_decay=hparams.weight_decay)
|
weight_decay=hparams.weight_decay)
|
||||||
|
|
||||||
if hparams.fp16_run:
|
if hparams.fp16_run:
|
||||||
optimizer = FP16_Optimizer(
|
from apex import amp
|
||||||
optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)
|
model, optimizer = amp.initialize(
|
||||||
|
model, optimizer, opt_level='O2')
|
||||||
|
|
||||||
if hparams.distributed_run:
|
if hparams.distributed_run:
|
||||||
model = apply_gradient_allreduce(model)
|
model = apply_gradient_allreduce(model)
|
||||||
|
|
@ -207,6 +197,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
|
||||||
epoch_offset = max(0, int(iteration / len(train_loader)))
|
epoch_offset = max(0, int(iteration / len(train_loader)))
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
is_overflow = False
|
||||||
# ================ MAIN TRAINNIG LOOP! ===================
|
# ================ MAIN TRAINNIG LOOP! ===================
|
||||||
for epoch in range(epoch_offset, hparams.epochs):
|
for epoch in range(epoch_offset, hparams.epochs):
|
||||||
print("Epoch: {}".format(epoch))
|
print("Epoch: {}".format(epoch))
|
||||||
|
|
@ -224,27 +215,30 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
|
||||||
reduced_loss = reduce_tensor(loss.data, n_gpus).item()
|
reduced_loss = reduce_tensor(loss.data, n_gpus).item()
|
||||||
else:
|
else:
|
||||||
reduced_loss = loss.item()
|
reduced_loss = loss.item()
|
||||||
|
|
||||||
if hparams.fp16_run:
|
if hparams.fp16_run:
|
||||||
optimizer.backward(loss)
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
|
scaled_loss.backward()
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
if hparams.fp16_run:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
amp.master_params(optimizer), hparams.grad_clip_thresh)
|
||||||
|
is_overflow = math.isnan(grad_norm)
|
||||||
|
else:
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
model.parameters(), hparams.grad_clip_thresh)
|
model.parameters(), hparams.grad_clip_thresh)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
overflow = optimizer.overflow if hparams.fp16_run else False
|
if not is_overflow and rank == 0:
|
||||||
|
|
||||||
if not overflow and not math.isnan(reduced_loss) and rank == 0:
|
|
||||||
duration = time.perf_counter() - start
|
duration = time.perf_counter() - start
|
||||||
print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
|
print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
|
||||||
iteration, reduced_loss, grad_norm, duration))
|
iteration, reduced_loss, grad_norm, duration))
|
||||||
logger.log_training(
|
logger.log_training(
|
||||||
reduced_loss, grad_norm, learning_rate, duration, iteration)
|
reduced_loss, grad_norm, learning_rate, duration, iteration)
|
||||||
|
|
||||||
if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
|
if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
|
||||||
validate(model, criterion, valset, iteration,
|
validate(model, criterion, valset, iteration,
|
||||||
hparams.batch_size, n_gpus, collate_fn, logger,
|
hparams.batch_size, n_gpus, collate_fn, logger,
|
||||||
hparams.distributed_run, rank)
|
hparams.distributed_run, rank)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue