diff --git a/tts.py b/tts.py index 30b7cbe..f115d83 100644 --- a/tts.py +++ b/tts.py @@ -16,10 +16,25 @@ from librosa import resample from librosa.effects import time_stretch from sia.file_utils import cached_model_path from sia.instruments import do_time +from glow import WaveGlow TTS_SAMPLE_RATE = 22050 OUTPUT_SAMPLE_RATE = 16000 +# https://github.com/NVIDIA/waveglow/blob/master/config.json +WAVEGLOW_CONFIG = { + "n_mel_channels": 80, + "n_flows": 12, + "n_group": 8, + "n_early_every": 4, + "n_early_size": 2, + "WN_config": { + "n_layers": 8, + "n_channels": 256, + "kernel_size": 3 + } +} + class TTSModel(object): """docstring for TTSModel.""" @@ -34,7 +49,9 @@ class TTSModel(object): torch.load(tacotron2_path, map_location='cpu')['state_dict']) self.model.eval() waveglow_path = cached_model_path('waveglow_model') - self.waveglow = torch.load(waveglow_path, map_location='cpu')['model'] + self.waveglow = WaveGlow(**WAVEGLOW_CONFIG) + wave_params = torch.load(waveglow_path, map_location='cpu') + self.waveglow.load_state_dict(wave_params) self.waveglow.eval() for k in self.waveglow.convinv: k.float()