diff --git a/setup.py b/setup.py index 1758a6d..256b918 100644 --- a/setup.py +++ b/setup.py @@ -12,12 +12,12 @@ with open("HISTORY.rst") as history_file: requirements = [ "klepto==0.1.6", - "numpy==1.16.4", + "numpy~=1.16.4", "inflect==0.2.5", "librosa==0.6.0", - "scipy==1.3.0", + "scipy~=1.3.0", "Unidecode==1.0.22", - "torch==1.1.0", + "torch~=1.1.0", "PyAudio==0.2.11" ] @@ -53,7 +53,7 @@ setup( test_suite="tests", tests_require=test_requirements, url="https://github.com/malarinv/tacotron2", - version="0.2.0", + version="0.3.0", zip_safe=False, entry_points={"console_scripts": ("tts_debug = taco2.tts:main",)}, ) diff --git a/taco2/hparams.py b/taco2/hparams.py index 6c632dc..3f15f20 100644 --- a/taco2/hparams.py +++ b/taco2/hparams.py @@ -35,7 +35,7 @@ class HParams(object): # Audio Parameters # ################################ max_wav_value = 32768.0 - sampling_rate = 16000 + sampling_rate = 22050 filter_length = 1024 hop_length = 256 win_length = 1024 diff --git a/taco2/tts.py b/taco2/tts.py index 3c51885..5c505fe 100644 --- a/taco2/tts.py +++ b/taco2/tts.py @@ -15,9 +15,9 @@ from .text import text_to_sequence from .denoiser import Denoiser from .audio_processing import griffin_lim, postprocess_audio -TTS_SAMPLE_RATE = 22050 OUTPUT_SAMPLE_RATE = 22050 -# OUTPUT_SAMPLE_RATE = 16000 +GL_ITERS = 30 +VOCODER_MODEL = "wavglow" # config from # https://github.com/NVIDIA/waveglow/blob/master/config.json @@ -37,7 +37,7 @@ class TTSModel(object): def __init__(self, tacotron2_path, waveglow_path, **kwargs): super(TTSModel, self).__init__() hparams = HParams(**kwargs) - hparams.sampling_rate = TTS_SAMPLE_RATE + self.hparams = hparams self.model = Tacotron2(hparams) if torch.cuda.is_available(): self.model.load_state_dict(torch.load(tacotron2_path)["state_dict"]) @@ -78,7 +78,7 @@ class TTSModel(object): ) else: self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)( - self.synth_speech_gl + self.synth_speech_fast ) self.taco_stft = TacotronSTFT( hparams.filter_length, @@ -101,23 +101,40 @@ class TTSModel(object): ) return mel_outputs_postnet - def synth_speech_array(self, text): + def synth_speech_array(self, text, vocoder): mel_outputs_postnet = self.generate_mel_postnet(text) - with torch.no_grad(): - audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666) - audio_t = self.denoiser(audio_t, 0.1)[0] - audio = audio_t[0].data.cpu().numpy() + if method == "wavglow": + with torch.no_grad(): + audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666) + audio_t = self.denoiser(audio_t, 0.1)[0] + audio = audio_t[0].data.cpu().numpy() + elif method == "gl": + mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet) + mel_decompress = mel_decompress.transpose(1, 2).data.cpu() + spec_from_mel_scaling = 1000 + spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis) + spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) + spec_from_mel = spec_from_mel * spec_from_mel_scaling + audio = griffin_lim( + torch.autograd.Variable(spec_from_mel[:, :, :-1]), + self.taco_stft.stft_fn, + 60, + ) + audio = audio.squeeze() + audio = audio.cpu().numpy() + else: + raise ValueError("vocoder arg should be one of [wavglow|gl]") return audio - def synth_speech(self, text): - audio = self.synth_speech_array(text) + def synth_speech(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE): + audio = self.synth_speech_array(text, VOCODER_MODEL) return postprocess_audio( - audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE + audio, src_rate=self.hparams.sample_rate, dst_rate=sample_rate, tempo=speed ) - def synth_speech_gl(self, text, griffin_iters=60): + def synth_speech_fast(self, text, speed: 1.0, sample_rate=OUTPUT_SAMPLE_RATE): mel_outputs_postnet = self.generate_mel_postnet(text) mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet) @@ -129,13 +146,13 @@ class TTSModel(object): audio = griffin_lim( torch.autograd.Variable(spec_from_mel[:, :, :-1]), self.taco_stft.stft_fn, - griffin_iters, + GL_ITERS, ) audio = audio.squeeze() audio = audio.cpu().numpy() return postprocess_audio( - audio, tempo=0.6, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE + audio, tempo=speed, src_rate=self.hparams.sample_rate, dst_rate=sample_rate, )