tested gl/wavglow working

master
Malar Kannan 2019-11-28 17:52:05 +05:30
parent 78eed2d295
commit ea11c5199e
2 changed files with 18 additions and 12 deletions

View File

@ -84,8 +84,8 @@ class STFT(torch.nn.Module):
forward_basis *= fft_window forward_basis *= fft_window
inverse_basis *= fft_window inverse_basis *= fft_window
self.register_buffer("forward_basis", forward_basis.float()) self.register_buffer("forward_basis", forward_basis.float().to(DEVICE))
self.register_buffer("inverse_basis", inverse_basis.float()) self.register_buffer("inverse_basis", inverse_basis.float().to(DEVICE))
def transform(self, input_data): def transform(self, input_data):
num_batches = input_data.size(0) num_batches = input_data.size(0)
@ -121,10 +121,10 @@ class STFT(torch.nn.Module):
return magnitude, phase return magnitude, phase
def inverse(self, magnitude, phase): def inverse(self, magnitude, phase):
phase = phase.to(DEVICE)
recombine_magnitude_phase = torch.cat( recombine_magnitude_phase = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
) )
inverse_transform = F.conv_transpose1d( inverse_transform = F.conv_transpose1d(
recombine_magnitude_phase, recombine_magnitude_phase,
Variable(self.inverse_basis, requires_grad=False), Variable(self.inverse_basis, requires_grad=False),
@ -144,11 +144,10 @@ class STFT(torch.nn.Module):
# remove modulation effects # remove modulation effects
approx_nonzero_indices = torch.from_numpy( approx_nonzero_indices = torch.from_numpy(
np.where(window_sum > tiny(window_sum))[0] np.where(window_sum > tiny(window_sum))[0]
) ).to(DEVICE)
window_sum = torch.autograd.Variable( window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False torch.from_numpy(window_sum), requires_grad=False
) ).to(DEVICE)
window_sum = window_sum.to(DEVICE)
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
approx_nonzero_indices approx_nonzero_indices
] ]

View File

@ -108,7 +108,7 @@ class TTSModel(object):
with torch.no_grad(): with torch.no_grad():
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666) audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
audio_t = self.denoiser(audio_t, 0.1)[0] audio_t = self.denoiser(audio_t, 0.1)[0]
audio = audio_t[0].data.cpu().numpy() audio = audio_t[0].data
elif vocoder == "gl": elif vocoder == "gl":
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet) mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
mel_decompress = mel_decompress.transpose(1, 2).data.cpu() mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
@ -116,25 +116,32 @@ class TTSModel(object):
spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis) spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
spec_from_mel = spec_from_mel * spec_from_mel_scaling spec_from_mel = spec_from_mel * spec_from_mel_scaling
spec_from_mel = (
spec_from_mel.cuda() if torch.cuda.is_available() else spec_from_mel
)
audio = griffin_lim( audio = griffin_lim(
torch.autograd.Variable(spec_from_mel[:, :, :-1]), torch.autograd.Variable(spec_from_mel[:, :, :-1]),
self.taco_stft.stft_fn, self.taco_stft.stft_fn,
60, 60,
) )
audio = audio.squeeze() audio = audio.squeeze()
audio = audio.cpu().numpy()
else: else:
raise ValueError("vocoder arg should be one of [wavglow|gl]") raise ValueError("vocoder arg should be one of [wavglow|gl]")
audio = audio.cpu().numpy()
return audio return audio
def synth_speech(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE): def synth_speech(
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
):
audio = self.synth_speech_array(text, VOCODER_MODEL) audio = self.synth_speech_array(text, VOCODER_MODEL)
return postprocess_audio( return postprocess_audio(
audio, src_rate=self.hparams.sample_rate, dst_rate=sample_rate, tempo=speed audio, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate, tempo=speed
) )
def synth_speech_fast(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE): def synth_speech_fast(
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
):
mel_outputs_postnet = self.generate_mel_postnet(text) mel_outputs_postnet = self.generate_mel_postnet(text)
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet) mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
@ -152,7 +159,7 @@ class TTSModel(object):
audio = audio.cpu().numpy() audio = audio.cpu().numpy()
return postprocess_audio( return postprocess_audio(
audio, tempo=speed, src_rate=self.hparams.sample_rate, dst_rate=sample_rate, audio, tempo=speed, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate
) )