diff --git a/glow.py b/glow.py
index 7c8e46c..b06c0c7 100644
--- a/glow.py
+++ b/glow.py
@@ -29,13 +29,14 @@ import torch
 from torch.autograd import Variable
 import torch.nn.functional as F
 
+DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
 @torch.jit.script
 def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     n_channels_int = n_channels[0]
     in_act = input_a+input_b
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :])
     acts = t_act * s_act
     return acts
@@ -90,7 +91,7 @@ class Invertible1x1Conv(torch.nn.Module):
             # Reverse computation
             W_inverse = W.float().inverse()
             W_inverse = Variable(W_inverse[..., None])
-            if z.type() == 'torch.HalfTensor':
+            if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
                 W_inverse = W_inverse.half()
             self.W_inverse = W_inverse
             z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
@@ -117,6 +118,7 @@ class WN(torch.nn.Module):
         self.n_channels = n_channels
         self.in_layers = torch.nn.ModuleList()
         self.res_skip_layers = torch.nn.ModuleList()
+        self.cond_layers = torch.nn.ModuleList()
 
         start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
         start = torch.nn.utils.weight_norm(start, name='weight')
@@ -129,9 +131,6 @@ class WN(torch.nn.Module):
         end.bias.data.zero_()
         self.end = end
 
-        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
-        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
-
         for i in range(n_layers):
             dilation = 2 ** i
             padding = int((kernel_size*dilation - dilation)/2)
@@ -140,6 +139,9 @@ class WN(torch.nn.Module):
             in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
             self.in_layers.append(in_layer)
 
+            cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
+            cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+            self.cond_layers.append(cond_layer)
 
             # last one is not necessary
             if i < n_layers - 1:
@@ -153,25 +155,24 @@ class WN(torch.nn.Module):
     def forward(self, forward_input):
         audio, spect = forward_input
         audio = self.start(audio)
-        output = torch.zeros_like(audio)
-        n_channels_tensor = torch.IntTensor([self.n_channels])
-
-        spect = self.cond_layer(spect)
 
         for i in range(self.n_layers):
-            spect_offset = i*2*self.n_channels
             acts = fused_add_tanh_sigmoid_multiply(
                 self.in_layers[i](audio),
-                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
-                n_channels_tensor)
+                self.cond_layers[i](spect),
+                torch.IntTensor([self.n_channels]))
 
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.n_layers - 1:
-                audio = audio + res_skip_acts[:,:self.n_channels,:]
-                output = output + res_skip_acts[:,self.n_channels:,:]
+                audio = res_skip_acts[:,:self.n_channels,:] + audio
+                skip_acts = res_skip_acts[:,self.n_channels:,:]
             else:
-                output = output + res_skip_acts
+                skip_acts = res_skip_acts
 
+            if i == 0:
+                output = skip_acts
+            else:
+                output = skip_acts + output
         return self.end(output)
@@ -257,14 +258,24 @@ class WaveGlow(torch.nn.Module):
         spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
         spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
 
-        if spect.type() == 'torch.HalfTensor':
-            audio = torch.HalfTensor(spect.size(0),
-                                     self.n_remaining_channels,
-                                     spect.size(2)).normal_()
+        if torch.cuda.is_available():
+            if spect.type() == 'torch.cuda.HalfTensor':
+                audio = torch.cuda.HalfTensor(spect.size(0),
+                                              self.n_remaining_channels,
+                                              spect.size(2)).normal_()
+            else:
+                audio = torch.cuda.FloatTensor(spect.size(0),
+                                               self.n_remaining_channels,
+                                               spect.size(2)).normal_()
         else:
-            audio = torch.FloatTensor(spect.size(0),
-                                      self.n_remaining_channels,
-                                      spect.size(2)).normal_()
+            if spect.type() == 'torch.HalfTensor':
+                audio = torch.HalfTensor(spect.size(0),
+                                         self.n_remaining_channels,
+                                         spect.size(2)).normal_()
+            else:
+                audio = torch.FloatTensor(spect.size(0),
+                                          self.n_remaining_channels,
+                                          spect.size(2)).normal_()
 
         audio = torch.autograd.Variable(sigma*audio)
@@ -274,7 +285,6 @@ class WaveGlow(torch.nn.Module):
             audio_1 = audio[:,n_half:,:]
 
             output = self.WN[k]((audio_0, spect))
-
             s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b)/torch.exp(s)
@@ -283,10 +293,16 @@ class WaveGlow(torch.nn.Module):
             audio = self.convinv[k](audio, reverse=True)
 
             if k % self.n_early_every == 0 and k > 0:
-                if spect.type() == 'torch.HalfTensor':
-                    z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+                if torch.cuda.is_available():
+                    if spect.type() == 'torch.cuda.HalfTensor':
+                        z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+                    else:
+                        z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                 else:
-                    z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+                    if spect.type() == 'torch.HalfTensor':
+                        z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+                    else:
+                        z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                 audio = torch.cat((sigma*z, audio),1)
 
         audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
@@ -298,7 +314,7 @@ class WaveGlow(torch.nn.Module):
     for WN in waveglow.WN:
         WN.start = torch.nn.utils.remove_weight_norm(WN.start)
         WN.in_layers = remove(WN.in_layers)
-        WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
+        WN.cond_layers = remove(WN.cond_layers)
         WN.res_skip_layers = remove(WN.res_skip_layers)
     return waveglow
diff --git a/taco2/denoiser.py b/taco2/denoiser.py
index de1836e..2989c30 100644
--- a/taco2/denoiser.py
+++ b/taco2/denoiser.py
@@ -9,9 +9,14 @@ class Denoiser(torch.nn.Module):
     def __init__(self, waveglow, filter_length=1024, n_overlap=4,
                  win_length=1024, mode='zeros', n_mel_channels=80,):
         super(Denoiser, self).__init__()
-        self.stft = STFT(filter_length=filter_length,
-                         hop_length=int(filter_length/n_overlap),
-                         win_length=win_length).cpu()
+        if torch.cuda.is_available():
+            self.stft = STFT(filter_length=filter_length,
+                             hop_length=int(filter_length/n_overlap),
+                             win_length=win_length).cuda()
+        else:
+            self.stft = STFT(filter_length=filter_length,
+                             hop_length=int(filter_length/n_overlap),
+                             win_length=win_length).cpu()
         if mode == 'zeros':
             mel_input = torch.zeros(
                 (1, n_mel_channels, 88),
@@ -32,7 +37,10 @@ class Denoiser(torch.nn.Module):
         self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
 
     def forward(self, audio, strength=0.1):
-        audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
+        if torch.cuda.is_available():
+            audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
+        else:
+            audio_spec, audio_angles = self.stft.transform(audio.cpu().float())
         audio_spec_denoised = audio_spec - self.bias_spec * strength
         audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
         audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
diff --git a/taco2/hparams.py b/taco2/hparams.py
index d123aea..6c632dc 100644
--- a/taco2/hparams.py
+++ b/taco2/hparams.py
@@ -39,9 +39,9 @@ class HParams(object):
     filter_length = 1024
     hop_length = 256
     win_length = 1024
-    n_mel_channels: int = 40
+    n_mel_channels: int = 80
     mel_fmin: float = 0.0
-    mel_fmax: float = 4000.0
+    mel_fmax: float = 8000.0
 
     ################################
     # Model Parameters             #
     ################################
diff --git a/taco2/stft.py b/taco2/stft.py
index eeaa94d..827362b 100644
--- a/taco2/stft.py
+++ b/taco2/stft.py
@@ -40,6 +40,7 @@ from scipy.signal import get_window
 from librosa.util import pad_center, tiny
 from .audio_processing import window_sumsquare
 
+DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
 class STFT(torch.nn.Module):
     """
@@ -147,9 +148,7 @@ class STFT(torch.nn.Module):
             window_sum = torch.autograd.Variable(
                 torch.from_numpy(window_sum), requires_grad=False
             )
-            # window_sum = window_sum.cuda() if magnitude.is_cuda else
-            # window_sum
-            # initially not commented out
+            window_sum = window_sum.to(DEVICE)
             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                 approx_nonzero_indices
             ]
diff --git a/taco2/tts.py b/taco2/tts.py
index bfe57c6..3c51885 100644
--- a/taco2/tts.py
+++ b/taco2/tts.py
@@ -39,20 +39,29 @@ class TTSModel(object):
         hparams = HParams(**kwargs)
         hparams.sampling_rate = TTS_SAMPLE_RATE
         self.model = Tacotron2(hparams)
-        self.model.load_state_dict(
-            torch.load(tacotron2_path, map_location="cpu")["state_dict"]
-        )
-        self.model.eval()
+        if torch.cuda.is_available():
+            self.model.load_state_dict(torch.load(tacotron2_path)["state_dict"])
+            self.model.cuda().eval()
+        else:
+            self.model.load_state_dict(
+                torch.load(tacotron2_path, map_location="cpu")["state_dict"]
+            )
+            self.model.eval()
         self.k_cache = klepto.archives.file_archive(cached=False)
         if waveglow_path:
-            wave_params = torch.load(waveglow_path, map_location="cpu")
+            if torch.cuda.is_available():
+                wave_params = torch.load(waveglow_path)
+            else:
+                wave_params = torch.load(waveglow_path, map_location="cpu")
             try:
                 self.waveglow = WaveGlow(**WAVEGLOW_CONFIG)
                 self.waveglow.load_state_dict(wave_params)
-                self.waveglow.eval()
             except:
                 self.waveglow = wave_params["model"]
                 self.waveglow = self.waveglow.remove_weightnorm(self.waveglow)
+            if torch.cuda.is_available():
+                self.waveglow.cuda().eval()
+            else:
+                self.waveglow.eval()
             # workaround from
             # https://github.com/NVIDIA/waveglow/issues/127
@@ -60,7 +69,7 @@
                 if "Conv" in str(type(m)):
                     setattr(m, "padding_mode", "zeros")
             for k in self.waveglow.convinv:
-                k.float()
+                k.float().half()
             self.denoiser = Denoiser(
                 self.waveglow, n_mel_channels=hparams.n_mel_channels
             )
@@ -82,20 +91,27 @@ class TTSModel(object):
     def generate_mel_postnet(self, text):
         sequence = np.array(text_to_sequence(text, ["english_cleaners"]))[None, :]
-        sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
+        if torch.cuda.is_available():
+            sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
+        else:
+            sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
         with torch.no_grad():
             mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(
                 sequence
             )
         return mel_outputs_postnet
 
-    def synth_speech(self, text):
+    def synth_speech_array(self, text):
         mel_outputs_postnet = self.generate_mel_postnet(text)
         with torch.no_grad():
             audio_t = self.waveglow.infer(mel_outputs_postnet, sigma=0.666)
         audio_t = self.denoiser(audio_t, 0.1)[0]
         audio = audio_t[0].data.cpu().numpy()
+        return audio
+
+    def synth_speech(self, text):
+        audio = self.synth_speech_array(text)
         return postprocess_audio(
             audio, src_rate=TTS_SAMPLE_RATE, dst_rate=OUTPUT_SAMPLE_RATE
         )
diff --git a/taco2/utils.py b/taco2/utils.py
index f779bf5..d53ca2e 100644
--- a/taco2/utils.py
+++ b/taco2/utils.py
@@ -27,6 +27,6 @@ def load_filepaths_and_text(filename, split="|"):
 
 def to_gpu(x):
     x = x.contiguous()
-    # if torch.cuda.is_available(): #initially not commented out
-    #     x = x.cuda(non_blocking=True) # initially not commented out
+    if torch.cuda.is_available(): #initially not commented out
+        x = x.cuda(non_blocking=True) # initially not commented out
     return torch.autograd.Variable(x)
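
Usage sketch (a note after the patch, not part of it): the diff splits the old synth_speech() into synth_speech_array(), which returns the raw numpy waveform at TTS_SAMPLE_RATE, plus a thin synth_speech() wrapper that still resamples via postprocess_audio(), and every load/eval path now picks CPU or GPU from torch.cuda.is_available(). Assuming TTSModel.__init__ accepts the tacotron2_path and waveglow_path arguments it uses above (the checkpoint paths below are placeholders), callers can exercise both entry points like this:

    # Minimal smoke test for the new two-step API; device selection is
    # automatic per the torch.cuda.is_available() branches in the patch.
    from taco2.tts import TTSModel

    model = TTSModel(
        tacotron2_path="checkpoints/tacotron2.pt",  # placeholder path
        waveglow_path="checkpoints/waveglow.pt",    # placeholder path
    )

    # New entry point: raw waveform at TTS_SAMPLE_RATE, no resampling.
    audio = model.synth_speech_array("Hello world.")

    # Same external behaviour as the old synth_speech(): resampled to
    # OUTPUT_SAMPLE_RATE via postprocess_audio().
    audio_out = model.synth_speech("Hello world.")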