Compare commits

..

No commits in common. "42a85d177ea8c933cae26464eabfd247d55c394e" and "ea11c5199ed3ca9814b08cab4d815180fc80a08c" have entirely different histories.

2 changed files with 28 additions and 20 deletions

View File

@ -17,7 +17,7 @@ from .audio_processing import griffin_lim, postprocess_audio
OUTPUT_SAMPLE_RATE = 22050 OUTPUT_SAMPLE_RATE = 22050
GL_ITERS = 30 GL_ITERS = 30
VOCODER_WAVEGLOW, VOCODER_GL = "wavglow", "gl" VOCODER_MODEL = "wavglow"
# config from # config from
# https://github.com/NVIDIA/waveglow/blob/master/config.json # https://github.com/NVIDIA/waveglow/blob/master/config.json
@ -74,11 +74,11 @@ class TTSModel(object):
self.waveglow, n_mel_channels=hparams.n_mel_channels self.waveglow, n_mel_channels=hparams.n_mel_channels
) )
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)( self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
self._synth_speech self.synth_speech
) )
else: else:
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)( self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
self._synth_speech_fast self.synth_speech_fast
) )
self.taco_stft = TacotronSTFT( self.taco_stft = TacotronSTFT(
hparams.filter_length, hparams.filter_length,
@ -89,7 +89,7 @@ class TTSModel(object):
mel_fmax=4000, mel_fmax=4000,
) )
def _generate_mel_postnet(self, text): def generate_mel_postnet(self, text):
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :] sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
if torch.cuda.is_available(): if torch.cuda.is_available():
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
@ -102,14 +102,14 @@ class TTSModel(object):
return mel_outputs_postnet return mel_outputs_postnet
def synth_speech_array(self, text, vocoder): def synth_speech_array(self, text, vocoder):
mel_outputs_postnet = self._generate_mel_postnet(text) mel_outputs_postnet = self.generate_mel_postnet(text)
if vocoder == VOCODER_WAVEGLOW: if vocoder == "wavglow":
with torch.no_grad(): with torch.no_grad():
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666) audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
audio_t = self.denoiser(audio_t, 0.1)[0] audio_t = self.denoiser(audio_t, 0.1)[0]
audio = audio_t[0].data audio = audio_t[0].data
elif vocoder == VOCODER_GL: elif vocoder == "gl":
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet) mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
mel_decompress = mel_decompress.transpose(1, 2).data.cpu() mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
spec_from_mel_scaling = 1000 spec_from_mel_scaling = 1000
@ -122,7 +122,7 @@ class TTSModel(object):
audio = griffin_lim( audio = griffin_lim(
torch.autograd.Variable(spec_from_mel[:, :, :-1]), torch.autograd.Variable(spec_from_mel[:, :, :-1]),
self.taco_stft.stft_fn, self.taco_stft.stft_fn,
GL_ITERS, 60,
) )
audio = audio.squeeze() audio = audio.squeeze()
else: else:
@ -130,28 +130,36 @@ class TTSModel(object):
audio = audio.cpu().numpy() audio = audio.cpu().numpy()
return audio return audio
def _synth_speech( def synth_speech(
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
): ):
audio = self.synth_speech_array(text, VOCODER_WAVEGLOW) audio = self.synth_speech_array(text, VOCODER_MODEL)
return postprocess_audio( return postprocess_audio(
audio, audio, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate, tempo=speed
src_rate=self.hparams.sampling_rate,
dst_rate=sample_rate,
tempo=speed,
) )
def _synth_speech_fast( def synth_speech_fast(
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
): ):
audio = self.synth_speech_array(text, VOCODER_GL) mel_outputs_postnet = self.generate_mel_postnet(text)
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
spec_from_mel_scaling = 1000
spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
spec_from_mel = spec_from_mel * spec_from_mel_scaling
audio = griffin_lim(
torch.autograd.Variable(spec_from_mel[:, :, :-1]),
self.taco_stft.stft_fn,
GL_ITERS,
)
audio = audio.squeeze()
audio = audio.cpu().numpy()
return postprocess_audio( return postprocess_audio(
audio, audio, tempo=speed, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate
tempo=speed,
src_rate=self.hparams.sampling_rate,
dst_rate=sample_rate,
) )