load waveglow model from statedict

master
Malar Kannan 2019-07-03 14:05:10 +05:30
parent 08ad9ce16e
commit 4be2475cc1
1 changed files with 18 additions and 1 deletions

19
tts.py
View File

@ -16,10 +16,25 @@ from librosa import resample
from librosa.effects import time_stretch
from sia.file_utils import cached_model_path
from sia.instruments import do_time
from glow import WaveGlow
TTS_SAMPLE_RATE = 22050
OUTPUT_SAMPLE_RATE = 16000
# https://github.com/NVIDIA/waveglow/blob/master/config.json
WAVEGLOW_CONFIG = {
"n_mel_channels": 80,
"n_flows": 12,
"n_group": 8,
"n_early_every": 4,
"n_early_size": 2,
"WN_config": {
"n_layers": 8,
"n_channels": 256,
"kernel_size": 3
}
}
class TTSModel(object):
"""docstring for TTSModel."""
@ -34,7 +49,9 @@ class TTSModel(object):
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
self.model.eval()
waveglow_path = cached_model_path('waveglow_model')
self.waveglow = torch.load(waveglow_path, map_location='cpu')['model']
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
wave_params = torch.load(waveglow_path, map_location='cpu')
self.waveglow.load_state_dict(wave_params)
self.waveglow.eval()
for k in self.waveglow.convinv:
k.float()