tacotron2/taco2/tts.py

176 lines
5.3 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import torch
import pyaudio
2019-09-20 19:49:30 +00:00
import klepto
2019-10-09 10:33:29 +00:00
import argparse
from pathlib import Path
from .model import Tacotron2
2019-09-20 19:49:30 +00:00
from glow import WaveGlow
from .hparams import HParams
from .layers import TacotronSTFT
2019-09-20 19:49:30 +00:00
from .text import text_to_sequence
from .denoiser import Denoiser
from .audio_processing import griffin_lim, postprocess_audio
# Sampling rate (Hz) assigned to the Tacotron2 hparams and passed as
# ``src_rate`` to postprocess_audio.
TTS_SAMPLE_RATE = 22050
# Sampling rate (Hz) passed as ``dst_rate`` to postprocess_audio and used
# for the PyAudio output stream.
OUTPUT_SAMPLE_RATE = 16000

# WaveGlow constructor keyword arguments, adapted from
# https://github.com/NVIDIA/waveglow/blob/master/config.json
# NOTE(review): n_mel_channels is 40 here; presumably this matches the mel
# channel count of this project's Tacotron2 checkpoints — confirm.
WAVEGLOW_CONFIG = dict(
    n_mel_channels=40,
    n_flows=12,
    n_group=8,
    n_early_every=4,
    n_early_size=2,
    WN_config=dict(n_layers=8, n_channels=256, kernel_size=3),
)
class TTSModel:
    """Text-to-speech wrapper around a Tacotron2 model and a vocoder.

    Tacotron2 turns text into a mel spectrogram. If a WaveGlow checkpoint is
    supplied, WaveGlow followed by a Denoiser vocodes it; otherwise
    Griffin-Lim reconstruction is used. ``synth_speech(text)`` is the entry
    point in both modes: ``__init__`` rebinds it on the instance to a
    klepto-memoised version of the selected vocoder path.

    NOTE(review): the klepto ``file_archive`` is created without an explicit
    name and the memo key appears to be built only from the call arguments,
    so cached audio may be reused across different checkpoints or vocoders —
    verify before relying on the cache when switching models.
    """

    def __init__(self, tacotron2_path, waveglow_path, **kwargs):
        """Load Tacotron2 (and optionally WaveGlow) onto the CPU in eval mode.

        Args:
            tacotron2_path: checkpoint whose ``torch.load`` result stores the
                Tacotron2 weights under ``"state_dict"``.
            waveglow_path: WaveGlow checkpoint path; a falsy value selects
                Griffin-Lim vocoding instead.
            **kwargs: forwarded to ``HParams``.
        """
        super().__init__()
        hparams = HParams(**kwargs)
        hparams.sampling_rate = TTS_SAMPLE_RATE
        self.model = Tacotron2(hparams)
        self.model.load_state_dict(
            torch.load(tacotron2_path, map_location="cpu")["state_dict"]
        )
        self.model.eval()
        self.k_cache = klepto.archives.file_archive(cached=False)
        if waveglow_path:
            self.waveglow = self._load_waveglow(waveglow_path)
            self.denoiser = Denoiser(
                self.waveglow, n_mel_channels=hparams.n_mel_channels
            )
            vocode = self.synth_speech
        else:
            vocode = self.synth_speech_gl
        # Shadow the class-level method with a memoised wrapper so repeated
        # texts are served from the klepto archive.
        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(vocode)
        # Within this class only synth_speech_gl reads taco_stft.
        # NOTE(review): mel_fmax is hard-coded to 4000 rather than taken from
        # hparams — presumably matches the training setup; confirm.
        self.taco_stft = TacotronSTFT(
            hparams.filter_length,
            hparams.hop_length,
            hparams.win_length,
            n_mel_channels=hparams.n_mel_channels,
            sampling_rate=hparams.sampling_rate,
            mel_fmax=4000,
        )

    @staticmethod
    def _load_waveglow(waveglow_path):
        """Load a WaveGlow vocoder from a checkpoint, ready for inference.

        Two checkpoint layouts are accepted: a plain state dict matching
        ``WAVEGLOW_CONFIG``, or a mapping holding the whole pickled module
        under ``"model"`` (whose weight norm is then removed).
        """
        wave_params = torch.load(waveglow_path, map_location="cpu")
        try:
            waveglow = WaveGlow(**WAVEGLOW_CONFIG)
            waveglow.load_state_dict(wave_params)
        except Exception:
            # Not a state dict for this config (load_state_dict raises on
            # missing/unexpected keys): fall back to the pickled-module
            # layout. Exception, not a bare except, so Ctrl-C still works.
            waveglow = wave_params["model"]
            waveglow = waveglow.remove_weightnorm(waveglow)
        waveglow.eval()
        # Workaround from https://github.com/NVIDIA/waveglow/issues/127:
        # set padding_mode explicitly on every Conv* module (presumably for
        # checkpoints pickled by an older PyTorch that lack the attribute).
        for module in waveglow.modules():
            if "Conv" in str(type(module)):
                module.padding_mode = "zeros"
        # Cast the invertible 1x1 convolutions to float32 (presumably to
        # cope with half-precision checkpoints — confirm).
        for conv in waveglow.convinv:
            conv.float()
        return waveglow

    def generate_mel_postnet(self, text):
        """Run Tacotron2 inference on *text*; return the post-net mel output.

        The text is cleaned with ``english_cleaners`` and encoded as a batch
        of one symbol sequence. The returned tensor is presumably shaped
        (1, n_mel_channels, frames) — confirm against Tacotron2.inference.
        """
        sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
        sequence = torch.from_numpy(sequence).long()
        with torch.no_grad():
            _, mel_outputs_postnet, _, _ = self.model.inference(sequence)
        return mel_outputs_postnet

    def synth_speech(self, text, sigma=0.666, denoiser_strength=0.1):
        """Synthesise *text* with WaveGlow followed by the denoiser.

        Args:
            text: sentence to speak.
            sigma: forwarded to ``WaveGlow.infer`` (presumably the latent
                noise scale).
            denoiser_strength: forwarded as the second argument of the
                Denoiser call.

        Returns:
            The result of ``postprocess_audio`` on the waveform, called with
            ``src_rate=TTS_SAMPLE_RATE`` and ``dst_rate=OUTPUT_SAMPLE_RATE``.
        """
        mel_outputs_postnet = self.generate_mel_postnet(text)
        with torch.no_grad():
            audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=sigma)
            audio_t = self.denoiser(audio_t, denoiser_strength)[0]
        audio = audio_t[0].detach().cpu().numpy()
        return postprocess_audio(
            audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
        )

    def synth_speech_gl(self, text, griffin_iters=60):
        """Synthesise *text* using Griffin-Lim reconstruction (no WaveGlow).

        Args:
            text: sentence to speak.
            griffin_iters: number of Griffin-Lim iterations.

        Returns:
            The result of ``postprocess_audio`` on the waveform, called with
            ``src_rate=TTS_SAMPLE_RATE`` and ``dst_rate=OUTPUT_SAMPLE_RATE``.
        """
        mel_outputs_postnet = self.generate_mel_postnet(text)
        mel_decompress = self.taco_stft.spectral_de_normalize(mel_outputs_postnet)
        # Swap the last two axes so the single batch item becomes a
        # (frames, n_mel) matrix for the product with mel_basis below.
        mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
        spec_from_mel_scaling = 1000
        spec_from_mel = torch.mm(mel_decompress[0], self.taco_stft.mel_basis)
        spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
        spec_from_mel = spec_from_mel * spec_from_mel_scaling
        # The final frame is dropped before reconstruction; the reason is not
        # evident from this file.
        audio = griffin_lim(
            spec_from_mel[:, :, :-1],
            self.taco_stft.stft_fn,
            griffin_iters,
        )
        audio = audio.squeeze().cpu().numpy()
        return postprocess_audio(
            audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
        )
def player_gen():
    """Open a PyAudio output stream and return a function that plays on it.

    The stream is mono, ``paInt16`` format, at OUTPUT_SAMPLE_RATE. It is
    never closed here; the returned callable keeps a reference to it.

    Returns:
        A one-argument callable that writes its argument to the stream.
        The argument is presumably int16 PCM bytes as produced by
        ``TTSModel.synth_speech`` — confirm against postprocess_audio.
    """
    pa = pyaudio.PyAudio()
    stream_kwargs = dict(
        format=pyaudio.paInt16,
        channels=1,
        rate=OUTPUT_SAMPLE_RATE,
        output=True,
    )
    stream = pa.open(**stream_kwargs)

    def _play(data):
        stream.write(data)

    return _play
def repl(tts_model):
    """Build one step of an interactive read-synthesise-play loop.

    Args:
        tts_model: object exposing ``synth_speech(text)`` (e.g. a TTSModel)
            whose return value the audio player accepts.

    Returns:
        A zero-argument callable. Each call prompts for one line of text,
        synthesises it and plays the result. Blank or whitespace-only lines
        are skipped. ``input`` exceptions (EOFError, KeyboardInterrupt)
        propagate to the caller.
    """
    player = player_gen()

    def loop():
        text = input("tts >").strip()
        if not text:
            # Nothing to say: don't feed an empty sequence to the model.
            return
        data = tts_model.synth_speech(text)
        player(data)

    return loop
def main():
    """Command-line entry point: load the models and run the TTS REPL.

    The loop runs until end of input (Ctrl-D, which makes ``input`` raise
    EOFError) or Ctrl-C, either of which now ends the program quietly
    instead of with a traceback.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-t",
        "--tacotron2_path",
        type=Path,
        default="./tacotron.pt",
        help="Path to a tacotron2 model",
    )
    parser.add_argument(
        "-w",
        "--waveglow_path",
        type=Path,
        default="./waveglow_256channels.pt",
        help="Path to a waveglow model",
    )
    args = parser.parse_args()
    tts_model = TTSModel(**vars(args))
    interactive_loop = repl(tts_model)
    try:
        while True:
            interactive_loop()
    except (EOFError, KeyboardInterrupt):
        # Finish the prompt line so the shell prompt starts cleanly.
        print()


if __name__ == "__main__":
    main()