From ea11c5199ed3ca9814b08cab4d815180fc80a08c Mon Sep 17 00:00:00 2001 From: Malar Kannan Date: Thu, 28 Nov 2019 17:52:05 +0530 Subject: [PATCH] tested gl/wavglow working --- taco2/stft.py | 11 +++++------ taco2/tts.py | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/taco2/stft.py b/taco2/stft.py index 827362b..21e29d0 100644 --- a/taco2/stft.py +++ b/taco2/stft.py @@ -84,8 +84,8 @@ class STFT(torch.nn.Module): forward_basis *= fft_window inverse_basis *= fft_window - self.register_buffer("forward_basis", forward_basis.float()) - self.register_buffer("inverse_basis", inverse_basis.float()) + self.register_buffer("forward_basis", forward_basis.float().to(DEVICE)) + self.register_buffer("inverse_basis", inverse_basis.float().to(DEVICE)) def transform(self, input_data): num_batches = input_data.size(0) @@ -121,10 +121,10 @@ class STFT(torch.nn.Module): return magnitude, phase def inverse(self, magnitude, phase): + phase = phase.to(DEVICE) recombine_magnitude_phase = torch.cat( [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 ) - inverse_transform = F.conv_transpose1d( recombine_magnitude_phase, Variable(self.inverse_basis, requires_grad=False), @@ -144,11 +144,10 @@ class STFT(torch.nn.Module): # remove modulation effects approx_nonzero_indices = torch.from_numpy( np.where(window_sum > tiny(window_sum))[0] - ) + ).to(DEVICE) window_sum = torch.autograd.Variable( torch.from_numpy(window_sum), requires_grad=False - ) - window_sum = window_sum.to(DEVICE) + ).to(DEVICE) inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ approx_nonzero_indices ] diff --git a/taco2/tts.py b/taco2/tts.py index 176a839..2f3405e 100644 --- a/taco2/tts.py +++ b/taco2/tts.py @@ -108,7 +108,7 @@ class TTSModel(object): with torch.no_grad(): audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666) audio_t = self.denoiser(audio_t, 0.1)[0] - audio = audio_t[0].data.cpu().numpy() + audio = audio_t[0].data elif vocoder == "gl": mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet) mel_decompress = mel_decompress.transpose(1, 2).data.cpu() @@ -116,25 +116,32 @@ class TTSModel(object): spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis) spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) spec_from_mel = spec_from_mel * spec_from_mel_scaling + spec_from_mel = ( + spec_from_mel.cuda() if torch.cuda.is_available() else spec_from_mel + ) audio = griffin_lim( torch.autograd.Variable(spec_from_mel[:, :, :-1]), self.taco_stft.stft_fn, 60, ) audio = audio.squeeze() - audio = audio.cpu().numpy() else: raise ValueError("vocoder arg should be one of [wavglow|gl]") + audio = audio.cpu().numpy() return audio - def synth_speech(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE): + def synth_speech( + self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE + ): audio = self.synth_speech_array(text, VOCODER_MODEL) return postprocess_audio( - audio, src_rate=self.hparams.sample_rate, dst_rate=sample_rate, tempo=speed + audio, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate, tempo=speed ) - def synth_speech_fast(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE): + def synth_speech_fast( + self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE + ): mel_outputs_postnet = self.generate_mel_postnet(text) mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet) @@ -152,7 +159,7 @@ class TTSModel(object): audio = audio.cpu().numpy() return postprocess_audio( - audio, tempo=speed, src_rate=self.hparams.sample_rate, dst_rate=sample_rate, + audio, tempo=speed, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate )