mirror of https://github.com/malarinv/tacotron2
train.py: patching score_mask_value formerly inf, not concrete value, for compatibility with pytorch
parent
cd851585cb
commit
1071023017
5
train.py
5
train.py
|
|
@ -2,6 +2,7 @@ import os
|
|||
import time
|
||||
import argparse
|
||||
import math
|
||||
from numpy import finfo
|
||||
|
||||
import torch
|
||||
from distributed import DistributedDataParallel
|
||||
|
|
@ -77,7 +78,9 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
|
|||
|
||||
def load_model(hparams):
|
||||
model = Tacotron2(hparams).cuda()
|
||||
model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
|
||||
if hparams.fp16_run:
|
||||
model = batchnorm_to_float(model.half())
|
||||
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
|
||||
|
||||
if hparams.distributed_run:
|
||||
model = DistributedDataParallel(model)
|
||||
|
|
|
|||
Loading…
Reference in New Issue