mirror of https://github.com/malarinv/tacotron2
load waveglow model from statedict
parent
08ad9ce16e
commit
4be2475cc1
19
tts.py
19
tts.py
|
|
@ -16,10 +16,25 @@ from librosa import resample
|
||||||
from librosa.effects import time_stretch
|
from librosa.effects import time_stretch
|
||||||
from sia.file_utils import cached_model_path
|
from sia.file_utils import cached_model_path
|
||||||
from sia.instruments import do_time
|
from sia.instruments import do_time
|
||||||
|
from glow import WaveGlow
|
||||||
|
|
||||||
TTS_SAMPLE_RATE = 22050
|
TTS_SAMPLE_RATE = 22050
|
||||||
OUTPUT_SAMPLE_RATE = 16000
|
OUTPUT_SAMPLE_RATE = 16000
|
||||||
|
|
||||||
|
# https://github.com/NVIDIA/waveglow/blob/master/config.json
|
||||||
|
WAVEGLOW_CONFIG = {
|
||||||
|
"n_mel_channels": 80,
|
||||||
|
"n_flows": 12,
|
||||||
|
"n_group": 8,
|
||||||
|
"n_early_every": 4,
|
||||||
|
"n_early_size": 2,
|
||||||
|
"WN_config": {
|
||||||
|
"n_layers": 8,
|
||||||
|
"n_channels": 256,
|
||||||
|
"kernel_size": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TTSModel(object):
|
class TTSModel(object):
|
||||||
"""docstring for TTSModel."""
|
"""docstring for TTSModel."""
|
||||||
|
|
@ -34,7 +49,9 @@ class TTSModel(object):
|
||||||
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
|
torch.load(tacotron2_path, map_location='cpu')['state_dict'])
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
waveglow_path = cached_model_path('waveglow_model')
|
waveglow_path = cached_model_path('waveglow_model')
|
||||||
self.waveglow = torch.load(waveglow_path, map_location='cpu')['model']
|
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
|
||||||
|
wave_params = torch.load(waveglow_path, map_location='cpu')
|
||||||
|
self.waveglow.load_state_dict(wave_params)
|
||||||
self.waveglow.eval()
|
self.waveglow.eval()
|
||||||
for k in self.waveglow.convinv:
|
for k in self.waveglow.convinv:
|
||||||
k.float()
|
k.float()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue