mirror of https://github.com/malarinv/tacotron2
Compare commits
2 Commits
ea11c5199e
...
42a85d177e
| Author | SHA1 | Date |
|---|---|---|
|
|
42a85d177e | |
|
|
5efb1e2758 |
48
taco2/tts.py
48
taco2/tts.py
|
|
@ -17,7 +17,7 @@ from .audio_processing import griffin_lim, postprocess_audio
|
||||||
|
|
||||||
OUTPUT_SAMPLE_RATE = 22050
|
OUTPUT_SAMPLE_RATE = 22050
|
||||||
GL_ITERS = 30
|
GL_ITERS = 30
|
||||||
VOCODER_MODEL = "wavglow"
|
VOCODER_WAVEGLOW, VOCODER_GL = "wavglow", "gl"
|
||||||
|
|
||||||
# config from
|
# config from
|
||||||
# https://github.com/NVIDIA/waveglow/blob/master/config.json
|
# https://github.com/NVIDIA/waveglow/blob/master/config.json
|
||||||
|
|
@ -74,11 +74,11 @@ class TTSModel(object):
|
||||||
self.waveglow, n_mel_channels=hparams.n_mel_channels
|
self.waveglow, n_mel_channels=hparams.n_mel_channels
|
||||||
)
|
)
|
||||||
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
|
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
|
||||||
self.synth_speech
|
self._synth_speech
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
|
self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(
|
||||||
self.synth_speech_fast
|
self._synth_speech_fast
|
||||||
)
|
)
|
||||||
self.taco_stft = TacotronSTFT(
|
self.taco_stft = TacotronSTFT(
|
||||||
hparams.filter_length,
|
hparams.filter_length,
|
||||||
|
|
@ -89,7 +89,7 @@ class TTSModel(object):
|
||||||
mel_fmax=4000,
|
mel_fmax=4000,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_mel_postnet(self, text):
|
def _generate_mel_postnet(self, text):
|
||||||
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
|
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
|
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
|
||||||
|
|
@ -102,14 +102,14 @@ class TTSModel(object):
|
||||||
return mel_outputs_postnet
|
return mel_outputs_postnet
|
||||||
|
|
||||||
def synth_speech_array(self, text, vocoder):
|
def synth_speech_array(self, text, vocoder):
|
||||||
mel_outputs_postnet = self.generate_mel_postnet(text)
|
mel_outputs_postnet = self._generate_mel_postnet(text)
|
||||||
|
|
||||||
if vocoder == "wavglow":
|
if vocoder == VOCODER_WAVEGLOW:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
|
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
|
||||||
audio_t = self.denoiser(audio_t, 0.1)[0]
|
audio_t = self.denoiser(audio_t, 0.1)[0]
|
||||||
audio = audio_t[0].data
|
audio = audio_t[0].data
|
||||||
elif vocoder == "gl":
|
elif vocoder == VOCODER_GL:
|
||||||
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
|
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
|
||||||
mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
|
mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
|
||||||
spec_from_mel_scaling = 1000
|
spec_from_mel_scaling = 1000
|
||||||
|
|
@ -122,7 +122,7 @@ class TTSModel(object):
|
||||||
audio = griffin_lim(
|
audio = griffin_lim(
|
||||||
torch.autograd.Variable(spec_from_mel[:, :, :-1]),
|
torch.autograd.Variable(spec_from_mel[:, :, :-1]),
|
||||||
self.taco_stft.stft_fn,
|
self.taco_stft.stft_fn,
|
||||||
60,
|
GL_ITERS,
|
||||||
)
|
)
|
||||||
audio = audio.squeeze()
|
audio = audio.squeeze()
|
||||||
else:
|
else:
|
||||||
|
|
@ -130,36 +130,28 @@ class TTSModel(object):
|
||||||
audio = audio.cpu().numpy()
|
audio = audio.cpu().numpy()
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
def synth_speech(
|
def _synth_speech(
|
||||||
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
|
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
|
||||||
):
|
):
|
||||||
audio = self.synth_speech_array(text, VOCODER_MODEL)
|
audio = self.synth_speech_array(text, VOCODER_WAVEGLOW)
|
||||||
|
|
||||||
return postprocess_audio(
|
return postprocess_audio(
|
||||||
audio, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate, tempo=speed
|
audio,
|
||||||
|
src_rate=self.hparams.sampling_rate,
|
||||||
|
dst_rate=sample_rate,
|
||||||
|
tempo=speed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def synth_speech_fast(
|
def _synth_speech_fast(
|
||||||
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
|
self, text, speed: float = 1.0, sample_rate: int = OUTPUT_SAMPLE_RATE
|
||||||
):
|
):
|
||||||
mel_outputs_postnet = self.generate_mel_postnet(text)
|
audio = self.synth_speech_array(text, VOCODER_GL)
|
||||||
|
|
||||||
mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
|
|
||||||
mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
|
|
||||||
spec_from_mel_scaling = 1000
|
|
||||||
spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
|
|
||||||
spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
|
|
||||||
spec_from_mel = spec_from_mel * spec_from_mel_scaling
|
|
||||||
audio = griffin_lim(
|
|
||||||
torch.autograd.Variable(spec_from_mel[:, :, :-1]),
|
|
||||||
self.taco_stft.stft_fn,
|
|
||||||
GL_ITERS,
|
|
||||||
)
|
|
||||||
audio = audio.squeeze()
|
|
||||||
audio = audio.cpu().numpy()
|
|
||||||
|
|
||||||
return postprocess_audio(
|
return postprocess_audio(
|
||||||
audio, tempo=speed, src_rate=self.hparams.sampling_rate, dst_rate=sample_rate
|
audio,
|
||||||
|
tempo=speed,
|
||||||
|
src_rate=self.hparams.sampling_rate,
|
||||||
|
dst_rate=sample_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue