mirror of https://github.com/malarinv/tacotron2
load waveglow model from statedict
parent
08ad9ce16e
commit
4be2475cc1
19
tts.py
19
tts.py
|
|
@ -16,10 +16,25 @@ from librosa import resample
|
|||
from librosa.effects import time_stretch
|
||||
from sia.file_utils import cached_model_path
|
||||
from sia.instruments import do_time
|
||||
from glow import WaveGlow
|
||||
|
||||
TTS_SAMPLE_RATE = 22050
|
||||
OUTPUT_SAMPLE_RATE = 16000
|
||||
|
||||
# https://github.com/NVIDIA/waveglow/blob/master/config.json
|
||||
WAVEGLOW_CONFIG = {
|
||||
"n_mel_channels": 80,
|
||||
"n_flows": 12,
|
||||
"n_group": 8,
|
||||
"n_early_every": 4,
|
||||
"n_early_size": 2,
|
||||
"WN_config": {
|
||||
"n_layers": 8,
|
||||
"n_channels": 256,
|
||||
"kernel_size": 3
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TTSModel(object):
|
||||
"""docstring for TTSModel."""
|
||||
|
|
@ -34,7 +49,9 @@ class TTSModel(object):
|
|||
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
|
||||
self.model.eval()
|
||||
waveglow_path = cached_model_path('waveglow_model')
|
||||
self.waveglow = torch.load(waveglow_path, map_location='cpu')['model']
|
||||
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
|
||||
wave_params = torch.load(waveglow_path, map_location='cpu')
|
||||
self.waveglow.load_state_dict(wave_params)
|
||||
self.waveglow.eval()
|
||||
for k in self.waveglow.convinv:
|
||||
k.float()
|
||||
|
|
|
|||
Loading…
Reference in New Issue