enable gpu support if available

master
Malar Kannan 2019-11-27 21:42:10 +05:30
parent 5a30069f0a
commit ac5ffcf6d5
6 changed files with 87 additions and 48 deletions

72
glow.py
View File

@ -29,13 +29,14 @@ import torch
from torch.autograd import Variable from torch.autograd import Variable
import torch.nn.functional as F import torch.nn.functional as F
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
@torch.jit.script @torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0] n_channels_int = n_channels[0]
in_act = input_a+input_b in_act = input_a+input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :]) t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act acts = t_act * s_act
return acts return acts
@ -90,7 +91,7 @@ class Invertible1x1Conv(torch.nn.Module):
# Reverse computation # Reverse computation
W_inverse = W.float().inverse() W_inverse = W.float().inverse()
W_inverse = Variable(W_inverse[..., None]) W_inverse = Variable(W_inverse[..., None])
if z.type() == 'torch.HalfTensor': if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
W_inverse = W_inverse.half() W_inverse = W_inverse.half()
self.W_inverse = W_inverse self.W_inverse = W_inverse
z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
@ -117,6 +118,7 @@ class WN(torch.nn.Module):
self.n_channels = n_channels self.n_channels = n_channels
self.in_layers = torch.nn.ModuleList() self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList() self.res_skip_layers = torch.nn.ModuleList()
self.cond_layers = torch.nn.ModuleList()
start = torch.nn.Conv1d(n_in_channels, n_channels, 1) start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
start = torch.nn.utils.weight_norm(start, name='weight') start = torch.nn.utils.weight_norm(start, name='weight')
@ -129,9 +131,6 @@ class WN(torch.nn.Module):
end.bias.data.zero_() end.bias.data.zero_()
self.end = end self.end = end
cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
for i in range(n_layers): for i in range(n_layers):
dilation = 2 ** i dilation = 2 ** i
padding = int((kernel_size*dilation - dilation)/2) padding = int((kernel_size*dilation - dilation)/2)
@ -140,6 +139,9 @@ class WN(torch.nn.Module):
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
self.cond_layers.append(cond_layer)
# last one is not necessary # last one is not necessary
if i < n_layers - 1: if i < n_layers - 1:
@ -153,25 +155,24 @@ class WN(torch.nn.Module):
def forward(self, forward_input): def forward(self, forward_input):
audio, spect = forward_input audio, spect = forward_input
audio = self.start(audio) audio = self.start(audio)
output = torch.zeros_like(audio)
n_channels_tensor = torch.IntTensor([self.n_channels])
spect = self.cond_layer(spect)
for i in range(self.n_layers): for i in range(self.n_layers):
spect_offset = i*2*self.n_channels
acts = fused_add_tanh_sigmoid_multiply( acts = fused_add_tanh_sigmoid_multiply(
self.in_layers[i](audio), self.in_layers[i](audio),
spect[:,spect_offset:spect_offset+2*self.n_channels,:], self.cond_layers[i](spect),
n_channels_tensor) torch.IntTensor([self.n_channels]))
res_skip_acts = self.res_skip_layers[i](acts) res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1: if i < self.n_layers - 1:
audio = audio + res_skip_acts[:,:self.n_channels,:] audio = res_skip_acts[:,:self.n_channels,:] + audio
output = output + res_skip_acts[:,self.n_channels:,:] skip_acts = res_skip_acts[:,self.n_channels:,:]
else: else:
output = output + res_skip_acts skip_acts = res_skip_acts
if i == 0:
output = skip_acts
else:
output = skip_acts + output
return self.end(output) return self.end(output)
@ -257,14 +258,24 @@ class WaveGlow(torch.nn.Module):
spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
if spect.type() == 'torch.HalfTensor': if torch.cuda.is_available():
audio = torch.HalfTensor(spect.size(0), if spect.type() == 'torch.cuda.HalfTensor':
self.n_remaining_channels, audio = torch.cuda.HalfTensor(spect.size(0),
spect.size(2)).normal_() self.n_remaining_channels,
spect.size(2)).normal_()
else:
audio = torch.cuda.FloatTensor(spect.size(0),
self.n_remaining_channels,
spect.size(2)).normal_()
else: else:
audio = torch.FloatTensor(spect.size(0), if spect.type() == 'torch.HalfTensor':
self.n_remaining_channels, audio = torch.HalfTensor(spect.size(0),
spect.size(2)).normal_() self.n_remaining_channels,
spect.size(2)).normal_()
else:
audio = torch.FloatTensor(spect.size(0),
self.n_remaining_channels,
spect.size(2)).normal_()
audio = torch.autograd.Variable(sigma*audio) audio = torch.autograd.Variable(sigma*audio)
@ -274,7 +285,6 @@ class WaveGlow(torch.nn.Module):
audio_1 = audio[:,n_half:,:] audio_1 = audio[:,n_half:,:]
output = self.WN[k]((audio_0, spect)) output = self.WN[k]((audio_0, spect))
s = output[:, n_half:, :] s = output[:, n_half:, :]
b = output[:, :n_half, :] b = output[:, :n_half, :]
audio_1 = (audio_1 - b)/torch.exp(s) audio_1 = (audio_1 - b)/torch.exp(s)
@ -283,10 +293,16 @@ class WaveGlow(torch.nn.Module):
audio = self.convinv[k](audio, reverse=True) audio = self.convinv[k](audio, reverse=True)
if k % self.n_early_every == 0 and k > 0: if k % self.n_early_every == 0 and k > 0:
if spect.type() == 'torch.HalfTensor': if torch.cuda.is_available():
z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() if spect.type() == 'torch.cuda.HalfTensor':
z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
else:
z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
else: else:
z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() if spect.type() == 'torch.HalfTensor':
z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
else:
z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
audio = torch.cat((sigma*z, audio),1) audio = torch.cat((sigma*z, audio),1)
audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
@ -298,7 +314,7 @@ class WaveGlow(torch.nn.Module):
for WN in waveglow.WN: for WN in waveglow.WN:
WN.start = torch.nn.utils.remove_weight_norm(WN.start) WN.start = torch.nn.utils.remove_weight_norm(WN.start)
WN.in_layers = remove(WN.in_layers) WN.in_layers = remove(WN.in_layers)
WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer) WN.cond_layers = remove(WN.cond_layers)
WN.res_skip_layers = remove(WN.res_skip_layers) WN.res_skip_layers = remove(WN.res_skip_layers)
return waveglow return waveglow

View File

@ -9,9 +9,14 @@ class Denoiser(torch.nn.Module):
def __init__(self, waveglow, filter_length=1024, n_overlap=4, def __init__(self, waveglow, filter_length=1024, n_overlap=4,
win_length=1024, mode='zeros', n_mel_channels=80,): win_length=1024, mode='zeros', n_mel_channels=80,):
super(Denoiser, self).__init__() super(Denoiser, self).__init__()
self.stft = STFT(filter_length=filter_length, if torch.cuda.is_available():
hop_length=int(filter_length/n_overlap), self.stft = STFT(filter_length=filter_length,
win_length=win_length).cpu() hop_length=int(filter_length/n_overlap),
win_length=win_length).cuda()
else:
self.stft = STFT(filter_length=filter_length,
hop_length=int(filter_length/n_overlap),
win_length=win_length).cpu()
if mode == 'zeros': if mode == 'zeros':
mel_input = torch.zeros( mel_input = torch.zeros(
(1, n_mel_channels, 88), (1, n_mel_channels, 88),
@ -32,7 +37,10 @@ class Denoiser(torch.nn.Module):
self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None]) self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
def forward(self, audio, strength=0.1): def forward(self, audio, strength=0.1):
audio_spec, audio_angles = self.stft.transform(audio.cpu().float()) if torch.cuda.is_available():
audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
else:
audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
audio_spec_denoised = audio_spec - self.bias_spec * strength audio_spec_denoised = audio_spec - self.bias_spec * strength
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)

View File

@ -39,9 +39,9 @@ class HParams(object):
filter_length = 1024 filter_length = 1024
hop_length = 256 hop_length = 256
win_length = 1024 win_length = 1024
n_mel_channels: int = 40 n_mel_channels: int = 80
mel_fmin: float = 0.0 mel_fmin: float = 0.0
mel_fmax: float = 4000.0 mel_fmax: float = 8000.0
################################ ################################
# Model Parameters # # Model Parameters #
################################ ################################

View File

@ -40,6 +40,7 @@ from scipy.signal import get_window
from librosa.util import pad_center, tiny from librosa.util import pad_center, tiny
from .audio_processing import window_sumsquare from .audio_processing import window_sumsquare
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class STFT(torch.nn.Module): class STFT(torch.nn.Module):
""" """
@ -147,9 +148,7 @@ class STFT(torch.nn.Module):
window_sum = torch.autograd.Variable( window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False torch.from_numpy(window_sum), requires_grad=False
) )
# window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum = window_sum.to(DEVICE)
# window_sum
# initially not commented out
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
approx_nonzero_indices approx_nonzero_indices
] ]

View File

@ -39,20 +39,29 @@ class TTSModel(object):
hparams = HParams(**kwargs) hparams = HParams(**kwargs)
hparams.sampling_rate = TTS_SAMPLE_RATE hparams.sampling_rate = TTS_SAMPLE_RATE
self.model = Tacotron2(hparams) self.model = Tacotron2(hparams)
self.model.load_state_dict( if torch.cuda.is_available():
torch.load(tacotron2_path, map_location="cpu")["state_dict"] self.model.load_state_dict(torch.load(tacotron2_path)["state_dict"])
) self.model.cuda().eval()
self.model.eval() else:
self.model.load_state_dict(
torch.load(tacotron2_path, map_location="cpu")["state_dict"]
)
self.model.eval()
self.k_cache = klepto.archives.file_archive(cached=False) self.k_cache = klepto.archives.file_archive(cached=False)
if waveglow_path: if waveglow_path:
wave_params = torch.load(waveglow_path, map_location="cpu") if torch.cuda.is_available():
wave_params = torch.load(waveglow_path)
else:
wave_params = torch.load(waveglow_path, map_location="cpu")
try: try:
self.waveglow = WaveGlow(**WAVEGLOW_CONFIG) self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
self.waveglow.load_state_dict(wave_params) self.waveglow.load_state_dict(wave_params)
self.waveglow.eval()
except: except:
self.waveglow = wave_params["model"] self.waveglow = wave_params["model"]
self.waveglow = self.waveglow.remove_weightnorm(self.waveglow) self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
if torch.cuda.is_available():
self.waveglow.cuda().eval()
else:
self.waveglow.eval() self.waveglow.eval()
# workaround from # workaround from
# https://github.com/NVIDIA/waveglow/issues/127 # https://github.com/NVIDIA/waveglow/issues/127
@ -60,7 +69,7 @@ class TTSModel(object):
if "Conv" in str(type(m)): if "Conv" in str(type(m)):
setattr(m, "padding_mode", "zeros") setattr(m, "padding_mode", "zeros")
for k in self.waveglow.convinv: for k in self.waveglow.convinv:
k.float() k.float().half()
self.denoiser = Denoiser( self.denoiser = Denoiser(
self.waveglow, n_mel_channels=hparams.n_mel_channels self.waveglow, n_mel_channels=hparams.n_mel_channels
) )
@ -82,20 +91,27 @@ class TTSModel(object):
def generate_mel_postnet(self, text): def generate_mel_postnet(self, text):
sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :] sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long() if torch.cuda.is_available():
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
else:
sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
with torch.no_grad(): with torch.no_grad():
mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference( mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
sequence sequence
) )
return mel_outputs_postnet return mel_outputs_postnet
def synth_speech(self, text): def synth_speech_array(self, text):
mel_outputs_postnet = self.generate_mel_postnet(text) mel_outputs_postnet = self.generate_mel_postnet(text)
with torch.no_grad(): with torch.no_grad():
audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666) audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
audio_t = self.denoiser(audio_t, 0.1)[0] audio_t = self.denoiser(audio_t, 0.1)[0]
audio = audio_t[0].data.cpu().numpy() audio = audio_t[0].data.cpu().numpy()
return audio
def synth_speech(self, text):
audio = self.synth_speech_array(text)
return postprocess_audio( return postprocess_audio(
audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE

View File

@ -27,6 +27,6 @@ def load_filepaths_and_text(filename, split="|"):
def to_gpu(x): def to_gpu(x):
x = x.contiguous() x = x.contiguous()
# if torch.cuda.is_available(): #initially not commented out if torch.cuda.is_available(): #initially not commented out
# x = x.cuda(non_blocking=True) # initially not commented out x = x.cuda(non_blocking=True) # initially not commented out
return torch.autograd.Variable(x) return torch.autograd.Variable(x)