mirror of https://github.com/malarinv/tacotron2
enable gpu support if available
parent
5a30069f0a
commit
ac5ffcf6d5
72
glow.py
72
glow.py
|
|
@ -29,13 +29,14 @@ import torch
|
|||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
|
||||
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a+input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
|
@ -90,7 +91,7 @@ class Invertible1x1Conv(torch.nn.Module):
|
|||
# Reverse computation
|
||||
W_inverse = W.float().inverse()
|
||||
W_inverse = Variable(W_inverse[..., None])
|
||||
if z.type() == 'torch.HalfTensor':
|
||||
if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
|
||||
W_inverse = W_inverse.half()
|
||||
self.W_inverse = W_inverse
|
||||
z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
|
||||
|
|
@ -117,6 +118,7 @@ class WN(torch.nn.Module):
|
|||
self.n_channels = n_channels
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.cond_layers = torch.nn.ModuleList()
|
||||
|
||||
start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
|
||||
start = torch.nn.utils.weight_norm(start, name='weight')
|
||||
|
|
@ -129,9 +131,6 @@ class WN(torch.nn.Module):
|
|||
end.bias.data.zero_()
|
||||
self.end = end
|
||||
|
||||
cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
||||
|
||||
for i in range(n_layers):
|
||||
dilation = 2 ** i
|
||||
padding = int((kernel_size*dilation - dilation)/2)
|
||||
|
|
@ -140,6 +139,9 @@ class WN(torch.nn.Module):
|
|||
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
|
||||
cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
||||
self.cond_layers.append(cond_layer)
|
||||
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
|
|
@ -153,25 +155,24 @@ class WN(torch.nn.Module):
|
|||
def forward(self, forward_input):
|
||||
audio, spect = forward_input
|
||||
audio = self.start(audio)
|
||||
output = torch.zeros_like(audio)
|
||||
n_channels_tensor = torch.IntTensor([self.n_channels])
|
||||
|
||||
spect = self.cond_layer(spect)
|
||||
|
||||
for i in range(self.n_layers):
|
||||
spect_offset = i*2*self.n_channels
|
||||
acts = fused_add_tanh_sigmoid_multiply(
|
||||
self.in_layers[i](audio),
|
||||
spect[:,spect_offset:spect_offset+2*self.n_channels,:],
|
||||
n_channels_tensor)
|
||||
self.cond_layers[i](spect),
|
||||
torch.IntTensor([self.n_channels]))
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
audio = audio + res_skip_acts[:,:self.n_channels,:]
|
||||
output = output + res_skip_acts[:,self.n_channels:,:]
|
||||
audio = res_skip_acts[:,:self.n_channels,:] + audio
|
||||
skip_acts = res_skip_acts[:,self.n_channels:,:]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
skip_acts = res_skip_acts
|
||||
|
||||
if i == 0:
|
||||
output = skip_acts
|
||||
else:
|
||||
output = skip_acts + output
|
||||
return self.end(output)
|
||||
|
||||
|
||||
|
|
@ -257,14 +258,24 @@ class WaveGlow(torch.nn.Module):
|
|||
spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
|
||||
spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
|
||||
|
||||
if spect.type() == 'torch.HalfTensor':
|
||||
audio = torch.HalfTensor(spect.size(0),
|
||||
self.n_remaining_channels,
|
||||
spect.size(2)).normal_()
|
||||
if torch.cuda.is_available():
|
||||
if spect.type() == 'torch.cuda.HalfTensor':
|
||||
audio = torch.cuda.HalfTensor(spect.size(0),
|
||||
self.n_remaining_channels,
|
||||
spect.size(2)).normal_()
|
||||
else:
|
||||
audio = torch.cuda.FloatTensor(spect.size(0),
|
||||
self.n_remaining_channels,
|
||||
spect.size(2)).normal_()
|
||||
else:
|
||||
audio = torch.FloatTensor(spect.size(0),
|
||||
self.n_remaining_channels,
|
||||
spect.size(2)).normal_()
|
||||
if spect.type() == 'torch.HalfTensor':
|
||||
audio = torch.HalfTensor(spect.size(0),
|
||||
self.n_remaining_channels,
|
||||
spect.size(2)).normal_()
|
||||
else:
|
||||
audio = torch.FloatTensor(spect.size(0),
|
||||
self.n_remaining_channels,
|
||||
spect.size(2)).normal_()
|
||||
|
||||
audio = torch.autograd.Variable(sigma*audio)
|
||||
|
||||
|
|
@ -274,7 +285,6 @@ class WaveGlow(torch.nn.Module):
|
|||
audio_1 = audio[:,n_half:,:]
|
||||
|
||||
output = self.WN[k]((audio_0, spect))
|
||||
|
||||
s = output[:, n_half:, :]
|
||||
b = output[:, :n_half, :]
|
||||
audio_1 = (audio_1 - b)/torch.exp(s)
|
||||
|
|
@ -283,10 +293,16 @@ class WaveGlow(torch.nn.Module):
|
|||
audio = self.convinv[k](audio, reverse=True)
|
||||
|
||||
if k % self.n_early_every == 0 and k > 0:
|
||||
if spect.type() == 'torch.HalfTensor':
|
||||
z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
if torch.cuda.is_available():
|
||||
if spect.type() == 'torch.cuda.HalfTensor':
|
||||
z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
else:
|
||||
z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
else:
|
||||
z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
if spect.type() == 'torch.HalfTensor':
|
||||
z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
else:
|
||||
z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
|
||||
audio = torch.cat((sigma*z, audio),1)
|
||||
|
||||
audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
|
||||
|
|
@ -298,7 +314,7 @@ class WaveGlow(torch.nn.Module):
|
|||
for WN in waveglow.WN:
|
||||
WN.start = torch.nn.utils.remove_weight_norm(WN.start)
|
||||
WN.in_layers = remove(WN.in_layers)
|
||||
WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
|
||||
WN.cond_layers = remove(WN.cond_layers)
|
||||
WN.res_skip_layers = remove(WN.res_skip_layers)
|
||||
return waveglow
|
||||
|
||||
|
|
|
|||
|
|
@ -9,9 +9,14 @@ class Denoiser(torch.nn.Module):
|
|||
def __init__(self, waveglow, filter_length=1024, n_overlap=4,
|
||||
win_length=1024, mode='zeros', n_mel_channels=80,):
|
||||
super(Denoiser, self).__init__()
|
||||
self.stft = STFT(filter_length=filter_length,
|
||||
hop_length=int(filter_length/n_overlap),
|
||||
win_length=win_length).cpu()
|
||||
if torch.cuda.is_available():
|
||||
self.stft = STFT(filter_length=filter_length,
|
||||
hop_length=int(filter_length/n_overlap),
|
||||
win_length=win_length).cuda()
|
||||
else:
|
||||
self.stft = STFT(filter_length=filter_length,
|
||||
hop_length=int(filter_length/n_overlap),
|
||||
win_length=win_length).cpu()
|
||||
if mode == 'zeros':
|
||||
mel_input = torch.zeros(
|
||||
(1, n_mel_channels, 88),
|
||||
|
|
@ -32,7 +37,10 @@ class Denoiser(torch.nn.Module):
|
|||
self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
|
||||
|
||||
def forward(self, audio, strength=0.1):
|
||||
audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
|
||||
if torch.cuda.is_available():
|
||||
audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
|
||||
else:
|
||||
audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
|
||||
audio_spec_denoised = audio_spec - self.bias_spec * strength
|
||||
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
|
||||
audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
|
||||
|
|
|
|||
|
|
@ -39,9 +39,9 @@ class HParams(object):
|
|||
filter_length = 1024
|
||||
hop_length = 256
|
||||
win_length = 1024
|
||||
n_mel_channels: int = 40
|
||||
n_mel_channels: int = 80
|
||||
mel_fmin: float = 0.0
|
||||
mel_fmax: float = 4000.0
|
||||
mel_fmax: float = 8000.0
|
||||
################################
|
||||
# Model Parameters #
|
||||
################################
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from scipy.signal import get_window
|
|||
from librosa.util import pad_center, tiny
|
||||
from .audio_processing import window_sumsquare
|
||||
|
||||
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
class STFT(torch.nn.Module):
|
||||
"""
|
||||
|
|
@ -147,9 +148,7 @@ class STFT(torch.nn.Module):
|
|||
window_sum = torch.autograd.Variable(
|
||||
torch.from_numpy(window_sum), requires_grad=False
|
||||
)
|
||||
# window_sum = window_sum.cuda() if magnitude.is_cuda else
|
||||
# window_sum
|
||||
# initially not commented out
|
||||
window_sum = window_sum.to(DEVICE)
|
||||
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
||||
approx_nonzero_indices
|
||||
]
|
||||
|
|
|
|||
34
taco2/tts.py
34
taco2/tts.py
|
|
@ -39,20 +39,29 @@ class TTSModel(object):
|
|||
hparams = HParams(**kwargs)
|
||||
hparams.sampling_rate = TTS_SAMPLE_RATE
|
||||
self.model = Tacotron2(hparams)
|
||||
self.model.load_state_dict(
|
||||
torch.load(tacotron2_path, map_location="cpu")["state_dict"]
|
||||
)
|
||||
self.model.eval()
|
||||
if torch.cuda.is_available():
|
||||
self.model.load_state_dict(torch.load(tacotron2_path)["state_dict"])
|
||||
self.model.cuda().eval()
|
||||
else:
|
||||
self.model.load_state_dict(
|
||||
torch.load(tacotron2_path, map_location="cpu")["state_dict"]
|
||||
)
|
||||
self.model.eval()
|
||||
self.k_cache = klepto.archives.file_archive(cached=False)
|
||||
if waveglow_path:
|
||||
wave_params = torch.load(waveglow_path, map_location="cpu")
|
||||
if torch.cuda.is_available():
|
||||
wave_params = torch.load(waveglow_path)
|
||||
else:
|
||||
wave_params = torch.load(waveglow_path, map_location="cpu")
|
||||
try:
|
||||
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
|
||||
self.waveglow.load_state_dict(wave_params)
|
||||
self.waveglow.eval()
|
||||
except:
|
||||
self.waveglow = wave_params["model"]
|
||||
self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
|
||||
if torch.cuda.is_available():
|
||||
self.waveglow.cuda().eval()
|
||||
else:
|
||||
self.waveglow.eval()
|
||||
# workaround from
|
||||
# https://github.com/NVIDIA/waveglow/issues/127
|
||||
|
|
@ -60,7 +69,7 @@ class TTSModel(object):
|
|||
if "Conv" in str(type(m)):
|
||||
setattr(m, "padding_mode", "zeros")
|
||||
for k in self.waveglow.convinv:
|
||||
k.float()
|
||||
k.float().half()
|
||||
self.denoiser = Denoiser(
|
||||
self.waveglow, n_mel_channels=hparams.n_mel_channels
|
||||
)
|
||||
|
|
@ -82,20 +91,27 @@ class TTSModel(object):
|
|||
|
||||
def generate_mel_postnet(self, text):
|
||||
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
|
||||
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
|
||||
if torch.cuda.is_available():
|
||||
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
|
||||
else:
|
||||
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
|
||||
with torch.no_grad():
|
||||
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
|
||||
sequence
|
||||
)
|
||||
return mel_outputs_postnet
|
||||
|
||||
def synth_speech(self, text):
|
||||
def synth_speech_array(self, text):
|
||||
mel_outputs_postnet = self.generate_mel_postnet(text)
|
||||
|
||||
with torch.no_grad():
|
||||
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
|
||||
audio_t = self.denoiser(audio_t, 0.1)[0]
|
||||
audio = audio_t[0].data.cpu().numpy()
|
||||
return audio
|
||||
|
||||
def synth_speech(self, text):
|
||||
audio = self.synth_speech_array(text)
|
||||
|
||||
return postprocess_audio(
|
||||
audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
|
||||
|
|
|
|||
|
|
@ -27,6 +27,6 @@ def load_filepaths_and_text(filename, split="|"):
|
|||
def to_gpu(x):
|
||||
x = x.contiguous()
|
||||
|
||||
# if torch.cuda.is_available(): #initially not commented out
|
||||
# x = x.cuda(non_blocking=True) # initially not commented out
|
||||
if torch.cuda.is_available(): #initially not commented out
|
||||
x = x.cuda(non_blocking=True) # initially not commented out
|
||||
return torch.autograd.Variable(x)
|
||||
|
|
|
|||
Loading…
Reference in New Issue