mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 01:32:35 +00:00
train.py: patching score_mask_value formerly inf, not concrete value, for compatibility with pytorch
This commit is contained in:
5
train.py
5
train.py
@@ -2,6 +2,7 @@ import os
|
||||
import time
|
||||
import argparse
|
||||
import math
|
||||
from numpy import finfo
|
||||
|
||||
import torch
|
||||
from distributed import DistributedDataParallel
|
||||
@@ -77,7 +78,9 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
|
||||
|
||||
def load_model(hparams):
|
||||
model = Tacotron2(hparams).cuda()
|
||||
model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
|
||||
if hparams.fp16_run:
|
||||
model = batchnorm_to_float(model.half())
|
||||
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
|
||||
|
||||
if hparams.distributed_run:
|
||||
model = DistributedDataParallel(model)
|
||||
|
||||
Reference in New Issue
Block a user