mirror of https://github.com/malarinv/tacotron2
tested gl/wavglow working
parent 78eed2d295
commit ea11c5199e
@@ -84,8 +84,8 @@ class STFT(torch.nn.Module):
             forward_basis *= fft_window
             inverse_basis *= fft_window

-        self.register_buffer("forward_basis", forward_basis.float())
-        self.register_buffer("inverse_basis", inverse_basis.float())
+        self.register_buffer("forward_basis", forward_basis.float().to(DEVICE))
+        self.register_buffer("inverse_basis", inverse_basis.float().to(DEVICE))

     def transform(self, input_data):
         num_batches = input_data.size(0)
@@ -121,10 +121,10 @@ class STFT(torch.nn.Module):
         return magnitude, phase

     def inverse(self, magnitude, phase):
+        phase = phase.to(DEVICE)
         recombine_magnitude_phase = torch.cat(
             [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
         )

         inverse_transform = F.conv_transpose1d(
             recombine_magnitude_phase,
             Variable(self.inverse_basis, requires_grad=False),
@@ -144,11 +144,10 @@ class STFT(torch.nn.Module):
             # remove modulation effects
             approx_nonzero_indices = torch.from_numpy(
                 np.where(window_sum > tiny(window_sum))[0]
-            )
+            ).to(DEVICE)
             window_sum = torch.autograd.Variable(
                 torch.from_numpy(window_sum), requires_grad=False
-            )
-            window_sum = window_sum.to(DEVICE)
+            ).to(DEVICE)
             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                 approx_nonzero_indices
             ]
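Note on the STFT hunks above: they all serve one goal, keeping the precomputed forward/inverse bases and the tensors they are combined with on the same device. Below is a minimal sketch of that pattern, assuming DEVICE is defined once at module level (its definition is not part of this diff) and using a hypothetical Toy module rather than the repo's STFT class:

import torch

# assumed definition; the diff only shows DEVICE being used, not defined
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Toy(torch.nn.Module):
    def __init__(self, n=4):
        super().__init__()
        # register the buffer already on DEVICE, mirroring forward_basis/inverse_basis
        self.register_buffer("basis", torch.eye(n).float().to(DEVICE))

    def forward(self, x):
        # move the incoming tensor to the same device, mirroring phase.to(DEVICE)
        x = x.to(DEVICE)
        return x @ self.basis


if __name__ == "__main__":
    print(Toy()(torch.randn(2, 4)).device)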
taco2/tts.py (19 changed lines)
@@ -108,7 +108,7 @@ class TTSModel(object):
             with torch.no_grad():
                 audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
                 audio_t = self.denoiser(audio_t, 0.1)[0]
-                audio = audio_t[0].data.cpu().numpy()
+                audio = audio_t[0].data
         elif vocoder == "gl":
             mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
             mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
@@ -116,25 +116,32 @@ class TTSModel(object):
             spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
             spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
             spec_from_mel = spec_from_mel * spec_from_mel_scaling
+            spec_from_mel = (
+                spec_from_mel.cuda() if torch.cuda.is_available() else spec_from_mel
+            )
             audio = griffin_lim(
                 torch.autograd.Variable(spec_from_mel[:, :, :-1]),
                 self.taco_stft.stft_fn,
                 60,
             )
             audio = audio.squeeze()
-            audio = audio.cpu().numpy()
         else:
             raise ValueError("vocoder arg should be one of [wavglow|gl]")
+        audio = audio.cpu().numpy()
         return audio

-    def synth_speech(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE):
+    def synth_speech(
+        self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
+    ):
         audio = self.synth_speech_array(text, VOCODER_MODEL)

         return postprocess_audio(
-            audio, src_rate=self.hparams.sample_rate, dst_rate=sample_rate, tempo=speed
+            audio, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate, tempo=speed
         )

-    def synth_speech_fast(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE):
+    def synth_speech_fast(
+        self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
+    ):
         mel_outputs_postnet = self.generate_mel_postnet(text)

         mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
@@ -152,7 +159,7 @@ class TTSModel(object):
         audio = audio.cpu().numpy()

         return postprocess_audio(
-            audio, tempo=speed, src_rate=self.hparams.sample_rate, dst_rate=sample_rate,
+            audio, tempo=speed, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate
         )
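Note on the taco2/tts.py hunks: synth_speech_array now keeps audio as a torch tensor inside both the waveglow and gl branches and performs the single .cpu().numpy() conversion after the branch, just before return audio. Below is a minimal sketch of that control flow, with hypothetical stand-in functions rather than the repo's waveglow/griffin_lim calls:

import torch


def waveglow_infer(mel):
    # stand-in for self.waveglow.infer(...) followed by the denoiser
    return torch.tanh(mel.mean(dim=1))


def griffin_lim_infer(mel):
    # stand-in for the spectral de-normalize + griffin_lim path
    return torch.tanh(mel.sum(dim=1))


def synth(mel, vocoder="gl"):
    if vocoder == "waveglow":
        audio = waveglow_infer(mel)
    elif vocoder == "gl":
        audio = griffin_lim_infer(mel)
    else:
        raise ValueError("vocoder arg should be one of [wavglow|gl]")
    # single conversion point: tensor -> NumPy array, whichever branch ran
    return audio.cpu().numpy()


if __name__ == "__main__":
    print(synth(torch.randn(1, 80, 100)).shape)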