From 0682eddfdcc6dececf1236d770ba52cf79c43ed9 Mon Sep 17 00:00:00 2001
From: Malar Kannan
Date: Thu, 26 Sep 2019 10:53:09 +0530
Subject: [PATCH] 1. add griffin lim support 2. add denoiser 3. add support to handle old and new waveglow models

---
 taco2/convert_model.py | 72 +++++++++++++++++++++++++++++++++++++
 taco2/data_utils.py    |  1 +
 taco2/denoiser.py      | 39 ++++++++++++++++++++
 taco2/hparams.py       |  1 +
 taco2/text_codec.py    | 81 ++++++++++++++++++++++++++++++++++++++++++
 taco2/tts.py           | 65 ++++++++++++++++++++++++++-------
 6 files changed, 246 insertions(+), 13 deletions(-)
 create mode 100644 taco2/convert_model.py
 create mode 100644 taco2/denoiser.py
 create mode 100644 taco2/text_codec.py

diff --git a/taco2/convert_model.py b/taco2/convert_model.py
new file mode 100644
index 0000000..32e77e2
--- /dev/null
+++ b/taco2/convert_model.py
@@ -0,0 +1,72 @@
+import sys
+import copy
+import torch
+
+def _check_model_old_version(model):
+    if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'):
+        return True
+    else:
+        return False
+
+
+def _update_model_res_skip(old_model, new_model):
+    for idx in range(0, len(new_model.WN)):
+        wavenet = new_model.WN[idx]
+        n_channels = wavenet.n_channels
+        n_layers = wavenet.n_layers
+        wavenet.res_skip_layers = torch.nn.ModuleList()
+        for i in range(0, n_layers):
+            if i < n_layers - 1:
+                res_skip_channels = 2*n_channels
+            else:
+                res_skip_channels = n_channels
+            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
+            skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
+            if i < n_layers - 1:
+                res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
+                res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
+                res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
+            else:
+                res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
+                res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+            wavenet.res_skip_layers.append(res_skip_layer)
+        del wavenet.res_layers
+        del wavenet.skip_layers
+
+def _update_model_cond(old_model, new_model):
+    for idx in range(0, len(new_model.WN)):
+        wavenet = new_model.WN[idx]
+        n_channels = wavenet.n_channels
+        n_layers = wavenet.n_layers
+        n_mel_channels = wavenet.cond_layers[0].weight.shape[1]
+        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
+        cond_layer_weight = []
+        cond_layer_bias = []
+        for i in range(0, n_layers):
+            _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i])
+            cond_layer_weight.append(_cond_layer.weight)
+            cond_layer_bias.append(_cond_layer.bias)
+        cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight))
+        cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias))
+        cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+        wavenet.cond_layer = cond_layer
+        del wavenet.cond_layers
+
+def update_model(old_model):
+    if not _check_model_old_version(old_model):
+        return old_model
+    new_model = copy.deepcopy(old_model)
+    if hasattr(old_model.WN[0], 'res_layers'):
+        _update_model_res_skip(old_model, new_model)
+    if hasattr(old_model.WN[0], 'cond_layers'):
+        _update_model_cond(old_model, new_model)
+    return new_model
+
+if __name__ == '__main__':
+    old_model_path = sys.argv[1]
+    new_model_path = sys.argv[2]
+    model = torch.load(old_model_path)
+    model['model'] = update_model(model['model'])
+    torch.save(model, new_model_path)
+
diff --git a/taco2/data_utils.py b/taco2/data_utils.py
index b845d94..b63dd89 100644
--- a/taco2/data_utils.py
+++ b/taco2/data_utils.py
@@ -7,6 +7,7 @@ import torch.utils.data
 from . import layers
 from .utils import load_wav_to_torch, load_filepaths_and_text
 from .text import text_to_sequence
+# from text_codec import text_to_sequence
 
 
 class TextMelLoader(torch.utils.data.Dataset):
diff --git a/taco2/denoiser.py b/taco2/denoiser.py
new file mode 100644
index 0000000..3d3f45d
--- /dev/null
+++ b/taco2/denoiser.py
@@ -0,0 +1,39 @@
+import sys
+import torch
+from .layers import STFT
+
+
+class Denoiser(torch.nn.Module):
+    """ Removes model bias from audio produced with waveglow """
+
+    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
+                 win_length=1024, mode='zeros'):
+        super(Denoiser, self).__init__()
+        self.stft = STFT(filter_length=filter_length,
+                         hop_length=int(filter_length/n_overlap),
+                         win_length=win_length).cpu()
+        if mode == 'zeros':
+            mel_input = torch.zeros(
+                (1, 80, 88),
+                dtype=waveglow.upsample.weight.dtype,
+                device=waveglow.upsample.weight.device)
+        elif mode == 'normal':
+            mel_input = torch.randn(
+                (1, 80, 88),
+                dtype=waveglow.upsample.weight.dtype,
+                device=waveglow.upsample.weight.device)
+        else:
+            raise Exception("Mode {} is not supported".format(mode))
+
+        with torch.no_grad():
+            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
+            bias_spec, _ = self.stft.transform(bias_audio)
+
+        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
+
+    def forward(self, audio, strength=0.1):
+        audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
+        audio_spec_denoised = audio_spec - self.bias_spec * strength
+        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
+        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
+        return audio_denoised
diff --git a/taco2/hparams.py b/taco2/hparams.py
index 1a126a7..255f5e2 100644
--- a/taco2/hparams.py
+++ b/taco2/hparams.py
@@ -2,6 +2,7 @@
 # import tensorflow as tf
 from dataclasses import dataclass
 from .text import symbols
+# from .text_codec import symbols
 
 @dataclass
 class HParams(object):
diff --git a/taco2/text_codec.py b/taco2/text_codec.py
new file mode 100644
index 0000000..f3da3c4
--- /dev/null
+++ b/taco2/text_codec.py
@@ -0,0 +1,81 @@
+from .utils import load_filepaths_and_text
+
+from .text import text_to_sequence, sequence_to_text
+
+
+import sentencepiece as spm
+from .text import symbols
+from bpemb import BPEmb
+
+
+SPM_CORPUS_FILE = "filelists/text_corpus.txt"
+SPM_MODEL_PREFIX = "spm"
+SPM_VOCAB_SIZE = 1000
+
+
+def _create_sentencepiece_corpus():
+    from .hparams import HParams
+    hparams = HParams()
+    def get_text_list(text_file):
+        return [i[1] + "\n" for i in load_filepaths_and_text(text_file)]
+
+    full_text_list = get_text_list(hparams.training_files) + get_text_list(
+        hparams.validation_files
+    )
+    with open(SPM_CORPUS_FILE, "w") as fd:
+        fd.writelines(full_text_list)
+
+
+def _create_sentencepiece_vocab(vocab_size=SPM_VOCAB_SIZE):
+    train_params = "--input={} --model_type=unigram --character_coverage=1.0 --model_prefix={} --vocab_size={}".format(
+        SPM_CORPUS_FILE, SPM_MODEL_PREFIX, vocab_size
+    )
+    spm.SentencePieceTrainer.Train(train_params)
+
+
+def _spm_text_codecs():
+    sp = spm.SentencePieceProcessor()
+    sp.Load("{}.model".format(SPM_MODEL_PREFIX))
+
+    def ttseq(text, cleaners):
+        return sp.EncodeAsIds(text)
+
+    def seqtt(sequence):
+        return sp.DecodeIds(sequence)
+
+    return ttseq, seqtt
+
+
+def _bpemb_text_codecs():
+    global bpemb_en
+    bpemb_en = BPEmb(lang="en", dim=50, vs=1000)
+    def ttseq(text, cleaners):
+        return bpemb_en.encode_ids(text)
+
+    def seqtt(sequence):
+        return bpemb_en.decode_ids(sequence)
+
+    return ttseq, seqtt
+
+# text_to_sequence, sequence_to_text = _spm_text_codecs()
+text_to_sequence, sequence_to_text = _bpemb_text_codecs()
+symbols = bpemb_en.words
+
+def _interactive_test():
+    from .hparams import HParams
+    hparams = HParams()
+    prompt = "Hello world; how are you, doing ?"
+    while prompt not in ["q", "quit"]:
+        oup = sequence_to_text(text_to_sequence(prompt, hparams.text_cleaners))
+        print('==> ',oup)
+        prompt = input("> ")
+
+
+def main():
+    # _create_sentencepiece_corpus()
+    # _create_sentencepiece_vocab()
+    _interactive_test()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/taco2/tts.py b/taco2/tts.py
index b1e41a9..e96afb2 100644
--- a/taco2/tts.py
+++ b/taco2/tts.py
@@ -11,6 +11,7 @@ from .model import Tacotron2
 from glow import WaveGlow
 from .hparams import HParams
 from .text import text_to_sequence
+from .denoiser import Denoiser
 
 TTS_SAMPLE_RATE = 22050
 OUTPUT_SAMPLE_RATE = 16000
@@ -18,7 +19,7 @@ OUTPUT_SAMPLE_RATE = 16000
 # config from
 # https://github.com/NVIDIA/waveglow/blob/master/config.json
 WAVEGLOW_CONFIG = {
-    "n_mel_channels": 80,
+    "n_mel_channels": 40,
     "n_flows": 12,
     "n_group": 8,
     "n_early_every": 4,
@@ -40,26 +41,36 @@ class TTSModel(object):
         )
         self.model.eval()
         wave_params = torch.load(waveglow_path, map_location="cpu")
-        self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
-        self.waveglow.load_state_dict(wave_params)
-        self.waveglow.eval()
-        for k in self.waveglow.convinv:
-            k.float()
-        self.k_cache = klepto.archives.file_archive(cached=False)
-        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
+        try:
+            self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
+            self.waveglow.load_state_dict(wave_params)
+            self.waveglow.eval()
+        except:
+            self.waveglow = wave_params['model']
+            self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
+            self.waveglow.eval()
         # workaround from
         # https://github.com/NVIDIA/waveglow/issues/127
         for m in self.waveglow.modules():
             if "Conv" in str(type(m)):
                 setattr(m, "padding_mode", "zeros")
+        for k in self.waveglow.convinv:
+            k.float()
+        self.k_cache = klepto.archives.file_archive(cached=False)
+        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
+        self.denoiser = Denoiser(self.waveglow)
 
-    def synth_speech(self, t):
-        text = t
+
+    def synth_speech(self, text):
         sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
         sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
         mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
+        # width = mel_outputs_postnet.shape[2]
+        # wave_glow_input = torch.randn(1, 80, width)*0.00001
+        # wave_glow_input[:,40:,:] = mel_outputs_postnet
         with torch.no_grad():
             audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
+            audio_t = self.denoiser(audio_t, 0.1)[0]
         audio = audio_t[0].data.cpu().numpy()
         # data = convert(audio)
         slow_data = time_stretch(audio, 0.8)
@@ -67,7 +78,30 @@ class TTSModel(object):
         data = float2pcm(float_data)
         return data.tobytes()
 
+    def synth_speech_algo(self,text,griffin_iters=60):
+        sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
+        sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
+        mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
+        from .hparams import HParams
+        from .layers import TacotronSTFT
+        from .audio_processing import griffin_lim
+        hparams = HParams()
+        taco_stft = TacotronSTFT(hparams.filter_length, hparams.hop_length, hparams.win_length, n_mel_channels=hparams.n_mel_channels, sampling_rate=hparams.sampling_rate, mel_fmax=4000)
+        mel_decompress = taco_stft.spectral_de_normalize(mel_outputs_postnet)
+        mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
+        spec_from_mel_scaling = 1000
+        spec_from_mel = torch.mm(mel_decompress[0], taco_stft.mel_basis)
+        spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
+        spec_from_mel = spec_from_mel * spec_from_mel_scaling
+        audio = griffin_lim(torch.autograd.Variable(spec_from_mel[:, :, :-1]), taco_stft.stft_fn, griffin_iters)
+        audio = audio.squeeze()
+        audio = audio.cpu().numpy()
+
 
+        slow_data = time_stretch(audio, 0.8)
+        float_data = resample(slow_data, TTS_SAMPLE_RATE, OUTPUT_SAMPLE_RATE)
+        data = float2pcm(float_data)
+        return data.tobytes()
 # adapted from
 # https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
 def float2pcm(sig, dtype="int16"):
@@ -140,9 +174,14 @@ def synthesize_corpus():
 
 def repl():
     tts_model = TTSModel(
-        "/Users/malar/Work/tacotron2_statedict.pt",
-        # "/Users/malar/Work/waveglow_256channels.pt",
-        "/Users/malar/Work/waveglow.pt",
+        # "/Users/malar/Work/tacotron2_statedict.pt",
+        # "/Users/malar/Work/tacotron2_80_22000.pt",
+        "/Users/malar/Work/tacotron2_80_66000.pt",
+        # "/Users/malar/Work/tacotron2_40_22000.pt",
+        # "/Users/malar/Work/tacotron2_16000.pt",
+        "/Users/malar/Work/waveglow_256converted.pt",
+        # "/Users/malar/Work/waveglow.pt",
+        # "/Users/malar/Work/waveglow_38000",
     )
     player = player_gen()
     def loop():
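
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the pieces added here are meant to fit together. An old-format WaveGlow checkpoint is converted once with convert_model.py, then loaded through TTSModel, whose synth_speech() now runs the Denoiser over the WaveGlow output and whose synth_speech_algo() provides the Griffin-Lim fallback. All file names are placeholders, and it assumes the taco2 package and the WaveGlow glow module are importable from the working directory.

    # Illustrative only; checkpoint paths below are placeholders, not files from this repo.
    # One-time conversion of an old-format WaveGlow checkpoint:
    #   python taco2/convert_model.py old_waveglow.pt waveglow_converted.pt
    from taco2.tts import TTSModel

    tts = TTSModel("tacotron2_checkpoint.pt", "waveglow_converted.pt")

    # WaveGlow + Denoiser path added by this patch; returns 16 kHz, 16-bit mono PCM bytes.
    pcm = tts.synth_speech("Hello world, how are you doing?")

    # Griffin-Lim path added by this patch; the mel spectrogram is inverted without the vocoder.
    pcm_gl = tts.synth_speech_algo("Hello world, how are you doing?", griffin_iters=60)

    with open("out_16k_s16le.raw", "wb") as fd:
        fd.write(pcm)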