1. add Griffin-Lim support
2. add denoiser
3. add support to handle old and new WaveGlow models
master
Malar Kannan 2019-09-26 10:53:09 +05:30
parent a10a6d517e
commit 0682eddfdc
6 changed files with 246 additions and 13 deletions

taco2/convert_model.py (new file)

@@ -0,0 +1,72 @@
import sys
import copy

import torch


def _check_model_old_version(model):
    # old checkpoints keep separate res/skip and per-layer cond convolutions
    if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'):
        return True
    else:
        return False


def _update_model_res_skip(old_model, new_model):
    # fuse the separate res_layers/skip_layers into the single res_skip_layers
    # ModuleList used by newer WaveGlow checkpoints
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        wavenet.res_skip_layers = torch.nn.ModuleList()
        for i in range(0, n_layers):
            if i < n_layers - 1:
                res_skip_channels = 2 * n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
            if i < n_layers - 1:
                res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
                res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
                res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
            else:
                res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
                res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            wavenet.res_skip_layers.append(res_skip_layer)
        del wavenet.res_layers
        del wavenet.skip_layers


def _update_model_cond(old_model, new_model):
    # concatenate the per-layer cond_layers into the single cond_layer
    # expected by newer WaveGlow checkpoints
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        n_mel_channels = wavenet.cond_layers[0].weight.shape[1]
        cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1)
        cond_layer_weight = []
        cond_layer_bias = []
        for i in range(0, n_layers):
            _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i])
            cond_layer_weight.append(_cond_layer.weight)
            cond_layer_bias.append(_cond_layer.bias)
        cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight))
        cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias))
        cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
        wavenet.cond_layer = cond_layer
        del wavenet.cond_layers


def update_model(old_model):
    if not _check_model_old_version(old_model):
        return old_model
    new_model = copy.deepcopy(old_model)
    if hasattr(old_model.WN[0], 'res_layers'):
        _update_model_res_skip(old_model, new_model)
    if hasattr(old_model.WN[0], 'cond_layers'):
        _update_model_cond(old_model, new_model)
    return new_model


if __name__ == '__main__':
    old_model_path = sys.argv[1]
    new_model_path = sys.argv[2]
    model = torch.load(old_model_path)
    model['model'] = update_model(model['model'])
    torch.save(model, new_model_path)
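
The conversion can also be driven from Python rather than the CLI. A minimal usage sketch (checkpoint paths are placeholders; `update_model` is a no-op for already-converted checkpoints):

import torch
from taco2.convert_model import update_model

ckpt = torch.load("waveglow_old.pt", map_location="cpu")  # placeholder path
ckpt['model'] = update_model(ckpt['model'])
torch.save(ckpt, "waveglow_converted.pt")                 # placeholder path

Equivalently, the __main__ block above supports `python taco2/convert_model.py <old.pt> <new.pt>`.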

@@ -7,6 +7,7 @@ import torch.utils.data
from . import layers
from .utils import load_wav_to_torch, load_filepaths_and_text
from .text import text_to_sequence
# from text_codec import text_to_sequence
class TextMelLoader(torch.utils.data.Dataset):

taco2/denoiser.py (new file)

@@ -0,0 +1,39 @@
import sys

import torch

from .layers import STFT


class Denoiser(torch.nn.Module):
    """ Removes model bias from audio produced with waveglow """

    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros'):
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).cpu()
        # synthesize the model's "silence" to estimate its bias spectrum
        if mode == 'zeros':
            mel_input = torch.zeros(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        elif mode == 'normal':
            mel_input = torch.randn(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        else:
            raise Exception("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised
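
A quick usage sketch for the new Denoiser (assumes `waveglow` is a loaded WaveGlow model and `mel` is a mel batch of shape (1, 80, T); values are illustrative):

denoiser = Denoiser(waveglow, mode='zeros')
with torch.no_grad():
    audio = waveglow.infer(mel, sigma=0.666)
audio_denoised = denoiser(audio, strength=0.1)  # subtracts the cached bias spectrum

The bias spectrum is computed once in __init__ by running WaveGlow on an all-zero (or random) mel input with sigma=0.0, so forward() only costs one STFT round trip per call.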

@@ -2,6 +2,7 @@
# import tensorflow as tf
from dataclasses import dataclass
from .text import symbols
# from .text_codec import symbols
@dataclass
class HParams(object):

taco2/text_codec.py (new file)

@@ -0,0 +1,81 @@
from .utils import load_filepaths_and_text
from .text import text_to_sequence, sequence_to_text
import sentencepiece as spm
from .text import symbols
from bpemb import BPEmb

SPM_CORPUS_FILE = "filelists/text_corpus.txt"
SPM_MODEL_PREFIX = "spm"
SPM_VOCAB_SIZE = 1000


def _create_sentencepiece_corpus():
    from .hparams import HParams
    hparams = HParams()

    def get_text_list(text_file):
        return [i[1] + "\n" for i in load_filepaths_and_text(text_file)]

    full_text_list = get_text_list(hparams.training_files) + get_text_list(
        hparams.validation_files
    )
    with open(SPM_CORPUS_FILE, "w") as fd:
        fd.writelines(full_text_list)


def _create_sentencepiece_vocab(vocab_size=SPM_VOCAB_SIZE):
    train_params = "--input={} --model_type=unigram --character_coverage=1.0 --model_prefix={} --vocab_size={}".format(
        SPM_CORPUS_FILE, SPM_MODEL_PREFIX, vocab_size
    )
    spm.SentencePieceTrainer.Train(train_params)


def _spm_text_codecs():
    sp = spm.SentencePieceProcessor()
    sp.Load("{}.model".format(SPM_MODEL_PREFIX))

    def ttseq(text, cleaners):
        return sp.EncodeAsIds(text)

    def seqtt(sequence):
        return sp.DecodeIds(sequence)

    return ttseq, seqtt


def _bpemb_text_codecs():
    global bpemb_en
    bpemb_en = BPEmb(lang="en", dim=50, vs=1000)

    def ttseq(text, cleaners):
        return bpemb_en.encode_ids(text)

    def seqtt(sequence):
        return bpemb_en.decode_ids(sequence)

    return ttseq, seqtt


# text_to_sequence, sequence_to_text = _spm_text_codecs()
text_to_sequence, sequence_to_text = _bpemb_text_codecs()
symbols = bpemb_en.words


def _interactive_test():
    from .hparams import HParams
    hparams = HParams()
    prompt = "Hello world; how are you, doing ?"
    while prompt not in ["q", "quit"]:
        oup = sequence_to_text(text_to_sequence(prompt, hparams.text_cleaners))
        print('==> ', oup)
        prompt = input("> ")


def main():
    # _create_sentencepiece_corpus()
    # _create_sentencepiece_vocab()
    _interactive_test()


if __name__ == "__main__":
    main()
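
A round-trip sketch for the BPEmb-backed codec above (the exact ids depend on the downloaded 1000-piece English BPEmb model):

ids = text_to_sequence("hello world, how are you?", cleaners=None)
text = sequence_to_text(ids)
print(ids, '=>', text)

Note that `cleaners` is accepted only for interface compatibility with taco2.text and is ignored by both the SentencePiece and BPEmb codecs.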

@@ -11,6 +11,7 @@ from .model import Tacotron2
from glow import WaveGlow
from .hparams import HParams
from .text import text_to_sequence
from .denoiser import Denoiser
TTS_SAMPLE_RATE = 22050
OUTPUT_SAMPLE_RATE = 16000
@@ -18,7 +19,7 @@ OUTPUT_SAMPLE_RATE = 16000

# config from
# https://github.com/NVIDIA/waveglow/blob/master/config.json
WAVEGLOW_CONFIG = {
    "n_mel_channels": 80,
    "n_mel_channels": 40,
    "n_flows": 12,
    "n_group": 8,
    "n_early_every": 4,
@@ -40,26 +41,36 @@ class TTSModel(object):
        )
        self.model.eval()
        wave_params = torch.load(waveglow_path, map_location="cpu")
        self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
        self.waveglow.load_state_dict(wave_params)
        self.waveglow.eval()
        for k in self.waveglow.convinv:
            k.float()
        self.k_cache = klepto.archives.file_archive(cached=False)
        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
        try:
            # new-style checkpoint: a bare state_dict
            self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
            self.waveglow.load_state_dict(wave_params)
            self.waveglow.eval()
        except:
            # old-style checkpoint: the full serialized model object
            self.waveglow = wave_params['model']
            self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
            self.waveglow.eval()
        # workaround from
        # https://github.com/NVIDIA/waveglow/issues/127
        for m in self.waveglow.modules():
            if "Conv" in str(type(m)):
                setattr(m, "padding_mode", "zeros")
        for k in self.waveglow.convinv:
            k.float()
        self.k_cache = klepto.archives.file_archive(cached=False)
        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
        self.denoiser = Denoiser(self.waveglow)

    def synth_speech(self, t):
        text = t

    def synth_speech(self, text):
        sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
        sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
        mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
        # width = mel_outputs_postnet.shape[2]
        # wave_glow_input = torch.randn(1, 80, width)*0.00001
        # wave_glow_input[:,40:,:] = mel_outputs_postnet
        with torch.no_grad():
            audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
            audio_t = self.denoiser(audio_t, 0.1)[0]
        audio = audio_t[0].data.cpu().numpy()
        # data = convert(audio)
        slow_data = time_stretch(audio, 0.8)
@@ -67,7 +78,30 @@ class TTSModel(object):
        data = float2pcm(float_data)
        return data.tobytes()

    def synth_speech_algo(self, text, griffin_iters=60):
        sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
        sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
        mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
        from .hparams import HParams
        from .layers import TacotronSTFT
        from .audio_processing import griffin_lim
        hparams = HParams()
        taco_stft = TacotronSTFT(
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            n_mel_channels=hparams.n_mel_channels,
            sampling_rate=hparams.sampling_rate, mel_fmax=4000)
        mel_decompress = taco_stft.spectral_de_normalize(mel_outputs_postnet)
        mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
        spec_from_mel_scaling = 1000
        spec_from_mel = torch.mm(mel_decompress[0], taco_stft.mel_basis)
        spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
        spec_from_mel = spec_from_mel * spec_from_mel_scaling
        audio = griffin_lim(torch.autograd.Variable(spec_from_mel[:, :, :-1]),
                            taco_stft.stft_fn, griffin_iters)
        audio = audio.squeeze()
        audio = audio.cpu().numpy()
        slow_data = time_stretch(audio, 0.8)
        float_data = resample(slow_data, TTS_SAMPLE_RATE, OUTPUT_SAMPLE_RATE)
        data = float2pcm(float_data)
        return data.tobytes()
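
For context, Griffin-Lim reconstructs a waveform from magnitudes alone by alternating phase estimation and resynthesis. A sketch of the iteration (assuming `stft_fn` exposes transform/inverse like taco2.layers.STFT; the actual implementation used above lives in .audio_processing):

import numpy as np
import torch

def griffin_lim_sketch(magnitudes, stft_fn, n_iters=60):
    # start from random phase; the known magnitudes stay fixed throughout
    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = torch.from_numpy(angles.astype(np.float32))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    for _ in range(n_iters):
        # re-estimate phase from the current signal, then resynthesize
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal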
# adapted from
# https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
def float2pcm(sig, dtype="int16"):
@@ -140,9 +174,14 @@ def synthesize_corpus()
def repl():
    tts_model = TTSModel(
        "/Users/malar/Work/tacotron2_statedict.pt",
        # "/Users/malar/Work/waveglow_256channels.pt",
        "/Users/malar/Work/waveglow.pt",
        # "/Users/malar/Work/tacotron2_statedict.pt",
        # "/Users/malar/Work/tacotron2_80_22000.pt",
        "/Users/malar/Work/tacotron2_80_66000.pt",
        # "/Users/malar/Work/tacotron2_40_22000.pt",
        # "/Users/malar/Work/tacotron2_16000.pt",
        "/Users/malar/Work/waveglow_256converted.pt",
        # "/Users/malar/Work/waveglow.pt",
        # "/Users/malar/Work/waveglow_38000",
    )
    player = player_gen()

    def loop():