diff --git a/taco2/tts.py b/taco2/tts.py
index 2f3405e..7f88252 100644
--- a/taco2/tts.py
+++ b/taco2/tts.py
@@ -17,7 +17,7 @@ from .audio_processing import griffin_lim, postprocess_audio
 
 OUTPUT_SAMPLE_RATE = 22050
 GL_ITERS = 30
-VOCODER_MODEL = "wavglow"
+VOCODER_WAVEGLOW, VOCODER_GL = "wavglow", "gl"
 
 # config from
 # https://github.com/NVIDIA/waveglow/blob/master/config.json
@@ -74,11 +74,11 @@ class TTSModel(object):
                 self.waveglow, n_mel_channels=hparams.n_mel_channels
             )
             self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
-                self.synth_speech
+                self._synth_speech
             )
         else:
             self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
-                self.synth_speech_fast
+                self._synth_speech_fast
             )
         self.taco_stft = TacotronSTFT(
             hparams.filter_length,
@@ -89,7 +89,7 @@
             mel_fmax=4000,
         )
 
-    def generate_mel_postnet(self, text):
+    def _generate_mel_postnet(self, text):
         sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
         if torch.cuda.is_available():
             sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
@@ -102,14 +102,14 @@
         return mel_outputs_postnet
 
     def synth_speech_array(self, text, vocoder):
-        mel_outputs_postnet = self.generate_mel_postnet(text)
+        mel_outputs_postnet = self._generate_mel_postnet(text)
 
-        if vocoder == "wavglow":
+        if vocoder == VOCODER_WAVEGLOW:
             with torch.no_grad():
                 audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
                 audio_t = self.denoiser(audio_t, 0.1)[0]
             audio = audio_t[0].data
-        elif vocoder == "gl":
+        elif vocoder == VOCODER_GL:
             mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
             mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
             spec_from_mel_scaling = 1000
@@ -122,7 +122,7 @@
             audio = griffin_lim(
                 torch.autograd.Variable(spec_from_mel[:, :, :-1]),
                 self.taco_stft.stft_fn,
-                60,
+                GL_ITERS,
             )
             audio = audio.squeeze()
         else:
@@ -130,36 +130,28 @@ class TTSModel(object):
         audio = audio.cpu().numpy()
         return audio
 
-    def synth_speech(
+    def _synth_speech(
         self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
     ):
-        audio = self.synth_speech_array(text, VOCODER_MODEL)
+        audio = self.synth_speech_array(text, VOCODER_WAVEGLOW)
         return postprocess_audio(
-            audio, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate, tempo=speed
+            audio,
+            src_rate=self.hparams.sampling_rate,
+            dst_rate=sample_rate,
+            tempo=speed,
         )
 
-    def synth_speech_fast(
+    def _synth_speech_fast(
         self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
     ):
-        mel_outputs_postnet = self.generate_mel_postnet(text)
-
-        mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
-        mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
-        spec_from_mel_scaling = 1000
-        spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
-        spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
-        spec_from_mel = spec_from_mel * spec_from_mel_scaling
-        audio = griffin_lim(
-            torch.autograd.Variable(spec_from_mel[:, :, :-1]),
-            self.taco_stft.stft_fn,
-            GL_ITERS,
-        )
-        audio = audio.squeeze()
-        audio = audio.cpu().numpy()
+        audio = self.synth_speech_array(text, VOCODER_GL)
         return postprocess_audio(
-            audio, tempo=speed, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate
+            audio,
+            tempo=speed,
+            src_rate=self.hparams.sampling_rate,
+            dst_rate=sample_rate,
        )
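
Usage note (commentary, not part of the patch): after this refactor the memoized public entry point is still synth_speech; the underscore-prefixed _synth_speech / _synth_speech_fast are the uncached implementations that __init__ wraps with klepto.safe.inf_cache, and both now delegate to synth_speech_array, which removed the duplicated Griffin-Lim block. A minimal caller sketch, assuming an already-constructed TTSModel instance named model (its constructor is outside this diff):

    from taco2.tts import OUTPUT_SAMPLE_RATE, VOCODER_GL, VOCODER_WAVEGLOW

    # model = TTSModel(...)  # construction not shown in this diff

    # Cached, high-level path; WaveGlow vs. Griffin-Lim was chosen at init time.
    audio = model.synth_speech("Hello world", speed=1.0, sample_rate=OUTPUT_SAMPLE_RATE)

    # Lower-level path: pick a vocoder explicitly via the new constants,
    # which replace the removed VOCODER_MODEL string.
    raw = model.synth_speech_array("Hello world", VOCODER_GL)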