mirror of https://github.com/malarinv/tacotron2
1. add griffin lim support
2. add denoiser
3. add support to handle old and new waveglow models

branch: master
parent a10a6d517e
commit 0682eddfdc
@@ -0,0 +1,72 @@
import sys
import copy
import torch


def _check_model_old_version(model):
    if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'):
        return True
    else:
        return False


def _update_model_res_skip(old_model, new_model):
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        wavenet.res_skip_layers = torch.nn.ModuleList()
        for i in range(0, n_layers):
            if i < n_layers - 1:
                res_skip_channels = 2 * n_channels
            else:
                res_skip_channels = n_channels
            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
            skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
            if i < n_layers - 1:
                res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
                res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
                res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
            else:
                res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
                res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            wavenet.res_skip_layers.append(res_skip_layer)
        del wavenet.res_layers
        del wavenet.skip_layers


def _update_model_cond(old_model, new_model):
    for idx in range(0, len(new_model.WN)):
        wavenet = new_model.WN[idx]
        n_channels = wavenet.n_channels
        n_layers = wavenet.n_layers
        n_mel_channels = wavenet.cond_layers[0].weight.shape[1]
        cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1)
        cond_layer_weight = []
        cond_layer_bias = []
        for i in range(0, n_layers):
            _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i])
            cond_layer_weight.append(_cond_layer.weight)
            cond_layer_bias.append(_cond_layer.bias)
        cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight))
        cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias))
        cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
        wavenet.cond_layer = cond_layer
        del wavenet.cond_layers


def update_model(old_model):
    if not _check_model_old_version(old_model):
        return old_model
    new_model = copy.deepcopy(old_model)
    if hasattr(old_model.WN[0], 'res_layers'):
        _update_model_res_skip(old_model, new_model)
    if hasattr(old_model.WN[0], 'cond_layers'):
        _update_model_cond(old_model, new_model)
    return new_model


if __name__ == '__main__':
    old_model_path = sys.argv[1]
    new_model_path = sys.argv[2]
    model = torch.load(old_model_path)
    model['model'] = update_model(model['model'])
    torch.save(model, new_model_path)
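For context, a hedged sketch of how this converter is meant to be used: run it as a script on an old-format checkpoint, or call update_model directly. The script file name and checkpoint paths below are hypothetical:

# From the shell (script name and paths are hypothetical):
#     python convert_model.py old_waveglow.pt converted_waveglow.pt
#
# Or programmatically, mirroring the __main__ block above:
import torch

checkpoint = torch.load("old_waveglow.pt")               # hypothetical old-format checkpoint
checkpoint['model'] = update_model(checkpoint['model'])  # no-op if already new-format
torch.save(checkpoint, "converted_waveglow.pt")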
@@ -7,6 +7,7 @@ import torch.utils.data
 from . import layers
 from .utils import load_wav_to_torch, load_filepaths_and_text
 from .text import text_to_sequence
+# from text_codec import text_to_sequence


 class TextMelLoader(torch.utils.data.Dataset):
@@ -0,0 +1,39 @@
import sys
import torch
from .layers import STFT


class Denoiser(torch.nn.Module):
    """ Removes model bias from audio produced with waveglow """

    def __init__(self, waveglow, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros'):
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).cpu()
        if mode == 'zeros':
            mel_input = torch.zeros(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        elif mode == 'normal':
            mel_input = torch.randn(
                (1, 80, 88),
                dtype=waveglow.upsample.weight.dtype,
                device=waveglow.upsample.weight.device)
        else:
            raise Exception("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
            bias_spec, _ = self.stft.transform(bias_audio)

        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
        audio_spec_denoised = audio_spec - self.bias_spec * strength
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
        return audio_denoised
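A minimal usage sketch, assuming waveglow is an already-loaded WaveGlow model and mel is a mel spectrogram batch (both hypothetical here): the denoiser is constructed once, capturing the vocoder's bias spectrum, then applied to each synthesized batch.

denoiser = Denoiser(waveglow)                 # estimates the bias spectrum once, at sigma=0.0
with torch.no_grad():
    audio = waveglow.infer(mel, sigma=0.666)  # raw vocoder output, shape (batch, samples)
denoised = denoiser(audio, strength=0.1)      # subtract strength * bias from the magnitudes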
@@ -2,6 +2,7 @@
 # import tensorflow as tf
 from dataclasses import dataclass
 from .text import symbols
+# from .text_codec import symbols


 @dataclass
 class HParams(object):
@@ -0,0 +1,81 @@
from .utils import load_filepaths_and_text

from .text import text_to_sequence, sequence_to_text


import sentencepiece as spm
from .text import symbols
from bpemb import BPEmb


SPM_CORPUS_FILE = "filelists/text_corpus.txt"
SPM_MODEL_PREFIX = "spm"
SPM_VOCAB_SIZE = 1000


def _create_sentencepiece_corpus():
    from .hparams import HParams
    hparams = HParams()

    def get_text_list(text_file):
        return [i[1] + "\n" for i in load_filepaths_and_text(text_file)]

    full_text_list = get_text_list(hparams.training_files) + get_text_list(
        hparams.validation_files
    )
    with open(SPM_CORPUS_FILE, "w") as fd:
        fd.writelines(full_text_list)


def _create_sentencepiece_vocab(vocab_size=SPM_VOCAB_SIZE):
    train_params = "--input={} --model_type=unigram --character_coverage=1.0 --model_prefix={} --vocab_size={}".format(
        SPM_CORPUS_FILE, SPM_MODEL_PREFIX, vocab_size
    )
    spm.SentencePieceTrainer.Train(train_params)


def _spm_text_codecs():
    sp = spm.SentencePieceProcessor()
    sp.Load("{}.model".format(SPM_MODEL_PREFIX))

    def ttseq(text, cleaners):
        return sp.EncodeAsIds(text)

    def seqtt(sequence):
        return sp.DecodeIds(sequence)

    return ttseq, seqtt


def _bpemb_text_codecs():
    global bpemb_en
    bpemb_en = BPEmb(lang="en", dim=50, vs=1000)

    def ttseq(text, cleaners):
        return bpemb_en.encode_ids(text)

    def seqtt(sequence):
        return bpemb_en.decode_ids(sequence)

    return ttseq, seqtt


# text_to_sequence, sequence_to_text = _spm_text_codecs()
text_to_sequence, sequence_to_text = _bpemb_text_codecs()
symbols = bpemb_en.words


def _interactive_test():
    from .hparams import HParams
    hparams = HParams()
    prompt = "Hello world; how are you, doing ?"
    while prompt not in ["q", "quit"]:
        oup = sequence_to_text(text_to_sequence(prompt, hparams.text_cleaners))
        print('==> ', oup)
        prompt = input("> ")


def main():
    # _create_sentencepiece_corpus()
    # _create_sentencepiece_vocab()
    _interactive_test()


if __name__ == "__main__":
    main()
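A hedged round-trip sketch of the BPEmb codec wired in above; the exact ids depend on the pretrained 1000-token English subword vocabulary, and the cleaners argument is accepted but ignored by this codec:

ids = text_to_sequence("Hello world", None)  # cleaners are unused by the BPEmb path
print(ids)                                   # subword ids from the vs=1000 vocabulary
print(sequence_to_text(ids))                 # round-trips to lowercased text (BPEmb lowercases by default)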
taco2/tts.py
@@ -11,6 +11,7 @@ from .model import Tacotron2
 from glow import WaveGlow
 from .hparams import HParams
 from .text import text_to_sequence
+from .denoiser import Denoiser

 TTS_SAMPLE_RATE = 22050
 OUTPUT_SAMPLE_RATE = 16000
@@ -18,7 +19,7 @@ OUTPUT_SAMPLE_RATE = 16000
 # config from
 # https://github.com/NVIDIA/waveglow/blob/master/config.json
 WAVEGLOW_CONFIG = {
-    "n_mel_channels": 80,
+    "n_mel_channels": 40,
     "n_flows": 12,
     "n_group": 8,
     "n_early_every": 4,
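Note that n_mel_channels here must agree with the checkpoint loaded into WaveGlow(**WAVEGLOW_CONFIG); on a mismatch load_state_dict raises, which the try/except added in the next hunk would also catch. A hedged illustration (checkpoint path hypothetical):

waveglow = WaveGlow(**WAVEGLOW_CONFIG)
# load_state_dict raises RuntimeError on a shape mismatch, e.g. a checkpoint
# trained with 80 mel channels loaded into this 40-channel config
waveglow.load_state_dict(torch.load("waveglow.pt", map_location="cpu"))  # hypothetical path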
@@ -40,26 +41,36 @@ class TTSModel(object):
         )
         self.model.eval()
         wave_params = torch.load(waveglow_path, map_location="cpu")
-        self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
-        self.waveglow.load_state_dict(wave_params)
-        self.waveglow.eval()
-        for k in self.waveglow.convinv:
-            k.float()
-        self.k_cache = klepto.archives.file_archive(cached=False)
-        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
+        try:
+            self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
+            self.waveglow.load_state_dict(wave_params)
+            self.waveglow.eval()
+        except:
+            self.waveglow = wave_params['model']
+            self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
+            self.waveglow.eval()
         # workaround from
         # https://github.com/NVIDIA/waveglow/issues/127
         for m in self.waveglow.modules():
             if "Conv" in str(type(m)):
                 setattr(m, "padding_mode", "zeros")
+        for k in self.waveglow.convinv:
+            k.float()
+        self.k_cache = klepto.archives.file_archive(cached=False)
+        self.synth_speech = klepto.safe.inf_cache(cache=self.k_cache)(self.synth_speech)
+        self.denoiser = Denoiser(self.waveglow)

-    def synth_speech(self, t):
-        text = t
+    def synth_speech(self, text):
         sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
         sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
         mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
+        # width = mel_outputs_postnet.shape[2]
+        # wave_glow_input = torch.randn(1, 80, width)*0.00001
+        # wave_glow_input[:,40:,:] = mel_outputs_postnet
         with torch.no_grad():
             audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
+        audio_t = self.denoiser(audio_t, 0.1)[0]
         audio = audio_t[0].data.cpu().numpy()
         # data = convert(audio)
         slow_data = time_stretch(audio, 0.8)
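The try/except distinguishes the two checkpoint layouts: newer exports are bare state dicts for a freshly constructed WaveGlow, while older ones pickle the entire model under a 'model' key. A minimal explicit sketch of the same dispatch (checkpoint path hypothetical):

wave_params = torch.load("waveglow.pt", map_location="cpu")  # hypothetical checkpoint
if isinstance(wave_params, dict) and "model" in wave_params:
    waveglow = wave_params["model"]                # old layout: whole module pickled
    waveglow = waveglow.remove_weightnorm(waveglow)
else:
    waveglow = WaveGlow(**WAVEGLOW_CONFIG)         # new layout: plain state dict
    waveglow.load_state_dict(wave_params)
waveglow.eval()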
@@ -67,7 +78,30 @@ class TTSModel(object):
         data = float2pcm(float_data)
         return data.tobytes()

+    def synth_speech_algo(self, text, griffin_iters=60):
+        sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
+        sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
+        mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
+        from .hparams import HParams
+        from .layers import TacotronSTFT
+        from .audio_processing import griffin_lim
+        hparams = HParams()
+        taco_stft = TacotronSTFT(
+            hparams.filter_length, hparams.hop_length, hparams.win_length,
+            n_mel_channels=hparams.n_mel_channels,
+            sampling_rate=hparams.sampling_rate, mel_fmax=4000)
+        mel_decompress = taco_stft.spectral_de_normalize(mel_outputs_postnet)
+        mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
+        spec_from_mel_scaling = 1000
+        spec_from_mel = torch.mm(mel_decompress[0], taco_stft.mel_basis)
+        spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
+        spec_from_mel = spec_from_mel * spec_from_mel_scaling
+
+        audio = griffin_lim(torch.autograd.Variable(spec_from_mel[:, :, :-1]),
+                            taco_stft.stft_fn, griffin_iters)
+        audio = audio.squeeze()
+        audio = audio.cpu().numpy()
+
+        slow_data = time_stretch(audio, 0.8)
+        float_data = resample(slow_data, TTS_SAMPLE_RATE, OUTPUT_SAMPLE_RATE)
+        data = float2pcm(float_data)
+        return data.tobytes()

 # adapted from
 # https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
 def float2pcm(sig, dtype="int16"):
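For comparison with the WaveGlow path, a hedged usage sketch: synth_speech_algo vocodes with Griffin-Lim instead of the neural vocoder, trading audio quality for not needing WaveGlow weights at synthesis time. Checkpoint paths are hypothetical:

model = TTSModel("tacotron2_80_66000.pt", "waveglow_256converted.pt")  # hypothetical paths
pcm = model.synth_speech_algo("Hello world.", griffin_iters=60)        # 16-bit PCM bytes
with open("hello.pcm", "wb") as fd:  # raw mono at OUTPUT_SAMPLE_RATE (16 kHz)
    fd.write(pcm)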
@@ -140,9 +174,14 @@ def synthesize_corpus():

 def repl():
     tts_model = TTSModel(
-        "/Users/malar/Work/tacotron2_statedict.pt",
-        # "/Users/malar/Work/waveglow_256channels.pt",
-        "/Users/malar/Work/waveglow.pt",
+        # "/Users/malar/Work/tacotron2_statedict.pt",
+        # "/Users/malar/Work/tacotron2_80_22000.pt",
+        "/Users/malar/Work/tacotron2_80_66000.pt",
+        # "/Users/malar/Work/tacotron2_40_22000.pt",
+        # "/Users/malar/Work/tacotron2_16000.pt",
+        "/Users/malar/Work/waveglow_256converted.pt",
+        # "/Users/malar/Work/waveglow.pt",
+        # "/Users/malar/Work/waveglow_38000",
     )
     player = player_gen()
     def loop():