#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Interactive text-to-speech: Tacotron2 mel synthesis + WaveGlow/Griffin-Lim vocoding."""

import numpy as np
import torch
import pyaudio
import klepto

from .model import Tacotron2
from glow import WaveGlow
from .hparams import HParams
from .layers import TacotronSTFT
from .text import text_to_sequence
from .denoiser import Denoiser
from .audio_processing import griffin_lim, postprocess_audio

# Rate the acoustic model is configured for (forced onto hparams below).
TTS_SAMPLE_RATE = 22050
# Rate handed to postprocess_audio as dst_rate and used for the playback stream.
OUTPUT_SAMPLE_RATE = 16000

# Checkpoints used by the interactive REPL when no paths are supplied.
DEFAULT_TACOTRON2_PATH = "/Users/malar/Work/tacotron2_r4_83000.pt"
DEFAULT_WAVEGLOW_PATH = "/Users/malar/Work/waveglow_484000"

# Adapted from
# https://github.com/NVIDIA/waveglow/blob/master/config.json
# NOTE(review): n_mel_channels is 40 here; presumably chosen to match the
# mel configuration of the Tacotron2 checkpoint in use — confirm.
WAVEGLOW_CONFIG = {
    "n_mel_channels": 40,
    "n_flows": 12,
    "n_group": 8,
    "n_early_every": 4,
    "n_early_size": 2,
    "WN_config": {"n_layers": 8, "n_channels": 256, "kernel_size": 3},
}


class TTSModel(object):
    """Tacotron2 text-to-mel model paired with a vocoder.

    When ``waveglow_path`` is truthy, audio is vocoded with WaveGlow followed
    by a denoiser; otherwise Griffin-Lim is used. In both cases callers use
    ``synth_speech(text)``, which is replaced per instance by a klepto-cached
    wrapper around the selected implementation.

    Args:
        tacotron2_path: Path to a Tacotron2 checkpoint whose ``"state_dict"``
            entry holds the model weights.
        waveglow_path: Path to a WaveGlow checkpoint, or a falsy value to use
            Griffin-Lim instead.
        **kwargs: Forwarded to ``HParams``. ``sampling_rate`` is always
            overridden with ``TTS_SAMPLE_RATE``.
    """

    def __init__(self, tacotron2_path, waveglow_path, **kwargs):
        super().__init__()
        hparams = HParams(**kwargs)
        hparams.sampling_rate = TTS_SAMPLE_RATE
        self.model = Tacotron2(hparams)
        self.model.load_state_dict(
            torch.load(tacotron2_path, map_location="cpu")["state_dict"]
        )
        self.model.eval()
        self.k_cache = klepto.archives.file_archive(cached=False)

        if waveglow_path:
            self.waveglow = self._load_waveglow(waveglow_path)
            self.denoiser = Denoiser(
                self.waveglow, n_mel_channels=hparams.n_mel_channels
            )
            synth_impl = self.synth_speech
        else:
            synth_impl = self.synth_speech_gl
        # The instance attribute shadows the synth_speech method, so every
        # call goes through the klepto cache regardless of the vocoder chosen.
        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(synth_impl)

        self.taco_stft = TacotronSTFT(
            hparams.filter_length,
            hparams.hop_length,
            hparams.win_length,
            n_mel_channels=hparams.n_mel_channels,
            sampling_rate=hparams.sampling_rate,
            mel_fmax=4000,
        )

    @staticmethod
    def _load_waveglow(waveglow_path):
        """Load a WaveGlow vocoder from ``waveglow_path``, ready for CPU inference.

        First tries to treat the file as a bare state dict for a freshly built
        ``WaveGlow(**WAVEGLOW_CONFIG)``. If that fails, falls back to reading a
        full model object stored under the ``"model"`` key and stripping its
        weight norm.
        """
        wave_params = torch.load(waveglow_path, map_location="cpu")
        try:
            waveglow = WaveGlow(**WAVEGLOW_CONFIG)
            waveglow.load_state_dict(wave_params)
            waveglow.eval()
        except Exception:
            # Not loadable as a bare state dict for WAVEGLOW_CONFIG; assume a
            # checkpoint dict that holds the whole module under "model".
            # Catching Exception (not a bare except) lets Ctrl-C / SystemExit
            # propagate instead of silently triggering this fallback.
            waveglow = wave_params["model"]
            waveglow = waveglow.remove_weightnorm(waveglow)
            waveglow.eval()
        # workaround from
        # https://github.com/NVIDIA/waveglow/issues/127
        for m in waveglow.modules():
            if "Conv" in str(type(m)):
                setattr(m, "padding_mode", "zeros")
        for k in waveglow.convinv:
            k.float()
        return waveglow

    def generate_mel_postnet(self, text):
        """Run Tacotron2 inference on ``text`` and return the post-net mel output."""
        sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
        # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
        sequence = torch.from_numpy(sequence).long()
        with torch.no_grad():
            _, mel_outputs_postnet, _, _ = self.model.inference(sequence)
        return mel_outputs_postnet

    def synth_speech(self, text):
        """Synthesize ``text`` with WaveGlow + denoiser.

        The result is post-processed from ``TTS_SAMPLE_RATE`` to
        ``OUTPUT_SAMPLE_RATE``.
        """
        mel_outputs_postnet = self.generate_mel_postnet(text)
        with torch.no_grad():
            audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
            audio_t = self.denoiser(audio_t, 0.1)[0]
        audio = audio_t[0].data.cpu().numpy()
        return postprocess_audio(
            audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
        )

    def synth_speech_gl(self, text, griffin_iters=60):
        """Synthesize ``text`` with Griffin-Lim phase reconstruction (no neural vocoder).

        Args:
            text: Input text.
            griffin_iters: Number of Griffin-Lim iterations.
        """
        mel_outputs_postnet = self.generate_mel_postnet(text)

        # Undo the mel normalisation, then project back to a linear spectrogram.
        mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
        mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
        spec_from_mel_scaling = 1000
        spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
        spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
        spec_from_mel = spec_from_mel * spec_from_mel_scaling

        # The last element along the final axis is dropped, as in the original.
        audio = griffin_lim(
            spec_from_mel[:, :, :-1],
            self.taco_stft.stft_fn,
            griffin_iters,
        )
        audio = audio.squeeze()
        audio = audio.cpu().numpy()
        return postprocess_audio(
            audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
        )


def player_gen():
    """Open a mono 16-bit PyAudio output stream and return a function that writes to it."""
    audio_interface = pyaudio.PyAudio()
    _audio_stream = audio_interface.open(
        format=pyaudio.paInt16, channels=1, rate=OUTPUT_SAMPLE_RATE, output=True
    )

    def play_device(data):
        _audio_stream.write(data)

    # NOTE(review): the stream and PyAudio instance are never closed or
    # terminated; they live until process exit.
    return play_device


def repl(tacotron2_path=DEFAULT_TACOTRON2_PATH, waveglow_path=DEFAULT_WAVEGLOW_PATH):
    """Build the model and audio player; return a function that runs one prompt cycle.

    Args:
        tacotron2_path: Tacotron2 checkpoint path (defaults to the original
            hard-coded location).
        waveglow_path: WaveGlow checkpoint path, or a falsy value for
            Griffin-Lim.
    """
    tts_model = TTSModel(tacotron2_path, waveglow_path)
    player = player_gen()

    def loop():
        text = input("tts >").strip()
        if not text:
            # Nothing to say; don't run the model on an empty sequence.
            return
        data = tts_model.synth_speech(text)
        player(data)

    return loop


def main():
    """Run the interactive TTS prompt until EOF (Ctrl-D) or Ctrl-C."""
    interactive_loop = repl()
    while True:
        try:
            interactive_loop()
        except (EOFError, KeyboardInterrupt):
            print()
            break


if __name__ == "__main__":
    main()