compatibility with pretrained models

Malar 2019-10-12 14:40:58 +05:30
parent 5a30069f0a
commit 342b230b93
3 changed files with 218 additions and 218 deletions

glow.py (166 changed lines)

@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 # *****************************************************************************
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
@@ -12,19 +13,18 @@
 # names of its contributors may be used to endorse or promote products
 # derived from this software without specific prior written permission.
 #
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 # *****************************************************************************
-import copy
 import torch
 from torch.autograd import Variable
 import torch.nn.functional as F
@@ -34,8 +34,8 @@ import torch.nn.functional as F
 def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     n_channels_int = n_channels[0]
     in_act = input_a + input_b
-    t_act = torch.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :])
     acts = t_act * s_act
     return acts
@@ -55,7 +55,11 @@ class WaveGlowLoss(torch.nn.Module):
             log_s_total = log_s_total + torch.sum(log_s)
             log_det_W_total += log_det_W_list[i]
-        loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
+        loss = (
+            torch.sum(z * z) / (2 * self.sigma * self.sigma)
+            - log_s_total
+            - log_det_W_total
+        )
         return loss / (z.size(0) * z.size(1) * z.size(2))
@@ -65,10 +69,12 @@ class Invertible1x1Conv(torch.nn.Module):
     of its weight matrix. If reverse=True it does convolution with
     inverse
     """
     def __init__(self, c):
         super(Invertible1x1Conv, self).__init__()
-        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
-                                    bias=False)
+        self.conv = torch.nn.Conv1d(
+            c, c, kernel_size=1, stride=1, padding=0, bias=False
+        )
         # Sample a random orthonormal matrix to initialize weights
         W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
@@ -86,11 +92,11 @@ class Invertible1x1Conv(torch.nn.Module):
         W = self.conv.weight.squeeze()
         if reverse:
-            if not hasattr(self, 'W_inverse'):
+            if not hasattr(self, "W_inverse"):
                 # Reverse computation
-                W_inverse = W.float().inverse()
+                W_inverse = W.inverse()
                 W_inverse = Variable(W_inverse[..., None])
-                if z.type() == 'torch.HalfTensor':
+                if z.type() == "torch.cuda.HalfTensor":
                     W_inverse = W_inverse.half()
                 self.W_inverse = W_inverse
             z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
@@ -104,22 +110,25 @@ class Invertible1x1Conv(torch.nn.Module):
 class WN(torch.nn.Module):
     """
-    This is the WaveNet like layer for the affine coupling. The primary difference
-    from WaveNet is the convolutions need not be causal. There is also no dilation
-    size reset. The dilation only doubles on each layer
+    This is the WaveNet like layer for the affine coupling. The primary
+    difference from WaveNet is the convolutions need not be causal. There is
+    also no dilation size reset. The dilation only doubles on each layer
     """
-    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
-                 kernel_size):
+    def __init__(
+        self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size
+    ):
         super(WN, self).__init__()
-        assert(kernel_size % 2 == 1)
-        assert(n_channels % 2 == 0)
+        assert kernel_size % 2 == 1
+        assert n_channels % 2 == 0
         self.n_layers = n_layers
         self.n_channels = n_channels
         self.in_layers = torch.nn.ModuleList()
         self.res_skip_layers = torch.nn.ModuleList()
+        self.cond_layers = torch.nn.ModuleList()
         start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
-        start = torch.nn.utils.weight_norm(start, name='weight')
+        start = torch.nn.utils.weight_norm(start, name="weight")
         self.start = start
         # Initializing last layer to 0 makes the affine coupling layers
@@ -129,17 +138,22 @@ class WN(torch.nn.Module):
         end.bias.data.zero_()
         self.end = end
-        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
-        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
         for i in range(n_layers):
             dilation = 2 ** i
             padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
-                                       dilation=dilation, padding=padding)
-            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+            in_layer = torch.nn.Conv1d(
+                n_channels,
+                2 * n_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+            )
+            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)
+            cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels, 1)
+            cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+            self.cond_layers.append(cond_layer)
             # last one is not necessary
             if i < n_layers - 1:
@@ -147,43 +161,51 @@ class WN(torch.nn.Module):
             else:
                 res_skip_channels = n_channels
             res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+            res_skip_layer = torch.nn.utils.weight_norm(
+                res_skip_layer, name="weight"
+            )
             self.res_skip_layers.append(res_skip_layer)
     def forward(self, forward_input):
         audio, spect = forward_input
         audio = self.start(audio)
-        output = torch.zeros_like(audio)
-        n_channels_tensor = torch.IntTensor([self.n_channels])
-        spect = self.cond_layer(spect)
         for i in range(self.n_layers):
-            spect_offset = i*2*self.n_channels
             acts = fused_add_tanh_sigmoid_multiply(
                 self.in_layers[i](audio),
-                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
-                n_channels_tensor)
+                self.cond_layers[i](spect),
+                torch.IntTensor([self.n_channels]),
+            )
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.n_layers - 1:
-                audio = audio + res_skip_acts[:,:self.n_channels,:]
-                output = output + res_skip_acts[:,self.n_channels:,:]
+                audio = res_skip_acts[:, : self.n_channels, :] + audio
+                skip_acts = res_skip_acts[:, self.n_channels :, :]
             else:
-                output = output + res_skip_acts
+                skip_acts = res_skip_acts
+            if i == 0:
+                output = skip_acts
+            else:
+                output = skip_acts + output
         return self.end(output)
 class WaveGlow(torch.nn.Module):
-    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
-                 n_early_size, WN_config):
+    def __init__(
+        self,
+        n_mel_channels,
+        n_flows,
+        n_group,
+        n_early_every,
+        n_early_size,
+        WN_config,
+    ):
         super(WaveGlow, self).__init__()
-        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
-                                                 n_mel_channels,
-                                                 1024, stride=256)
-        assert(n_group % 2 == 0)
+        self.upsample = torch.nn.ConvTranspose1d(
+            n_mel_channels, n_mel_channels, 1024, stride=256
+        )
+        assert n_group % 2 == 0
         self.n_flows = n_flows
         self.n_group = n_group
         self.n_early_every = n_early_every
@@ -202,7 +224,8 @@ class WaveGlow(torch.nn.Module):
         n_remaining_channels = n_remaining_channels - self.n_early_size
         self.convinv.append(Invertible1x1Conv(n_remaining_channels))
         self.WN.append(WN(n_half, n_mel_channels * n_group, **WN_config))
-        self.n_remaining_channels = n_remaining_channels  # Useful during inference
+        self.n_remaining_channels = n_remaining_channels
+        # Useful during inference
     def forward(self, forward_input):
         """
@@ -213,12 +236,16 @@ class WaveGlow(torch.nn.Module):
         # Upsample spectrogram to size of audio
         spect = self.upsample(spect)
-        assert(spect.size(2) >= audio.size(1))
+        assert spect.size(2) >= audio.size(1)
         if spect.size(2) > audio.size(1):
             spect = spect[:, :, : audio.size(1)]
         spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
-        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
+        spect = (
+            spect.contiguous()
+            .view(spect.size(0), spect.size(1), -1)
+            .permute(0, 2, 1)
+        )
         audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
         output_audio = []
@@ -255,16 +282,21 @@ class WaveGlow(torch.nn.Module):
         spect = spect[:, :, :-time_cutoff]
         spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
-        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
+        spect = (
+            spect.contiguous()
+            .view(spect.size(0), spect.size(1), -1)
+            .permute(0, 2, 1)
+        )
-        if spect.type() == 'torch.HalfTensor':
-            audio = torch.HalfTensor(spect.size(0),
-                                     self.n_remaining_channels,
-                                     spect.size(2)).normal_()
+        if spect.type() == "torch.cuda.HalfTensor":
+            audio = torch.cuda.HalfTensor(
+                spect.size(0), self.n_remaining_channels, spect.size(2)
+            ).normal_()
         else:
-            audio = torch.FloatTensor(spect.size(0),
-                                      self.n_remaining_channels,
-                                      spect.size(2)).normal_()
+            # cuda.FloatTensor -> FloatTensor
+            audio = torch.FloatTensor(
+                spect.size(0), self.n_remaining_channels, spect.size(2)
+            ).normal_()
         audio = torch.autograd.Variable(sigma * audio)
@@ -274,7 +306,6 @@ class WaveGlow(torch.nn.Module):
             audio_1 = audio[:, n_half:, :]
             output = self.WN[k]((audio_0, spect))
-
             s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -283,13 +314,20 @@ class WaveGlow(torch.nn.Module):
             audio = self.convinv[k](audio, reverse=True)
             if k % self.n_early_every == 0 and k > 0:
-                if spect.type() == 'torch.HalfTensor':
-                    z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+                if spect.type() == "torch.cuda.HalfTensor":
+                    z = torch.cuda.HalfTensor(
+                        spect.size(0), self.n_early_size, spect.size(2)
+                    ).normal_()
                 else:
-                    z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
+                    # cuda.FloatTensor -> FloatTensor
+                    z = torch.FloatTensor(
+                        spect.size(0), self.n_early_size, spect.size(2)
+                    ).normal_()
                 audio = torch.cat((sigma * z, audio), 1)
-        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
+        audio = (
+            audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
+        )
         return audio
     @staticmethod
@@ -298,7 +336,7 @@ class WaveGlow(torch.nn.Module):
         for WN in waveglow.WN:
             WN.start = torch.nn.utils.remove_weight_norm(WN.start)
             WN.in_layers = remove(WN.in_layers)
-            WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
+            WN.cond_layers = remove(WN.cond_layers)
            WN.res_skip_layers = remove(WN.res_skip_layers)
         return waveglow
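
The conditioning-layer change above is the part of this commit that decides whether an existing pretrained checkpoint's state_dict keys line up: per-layer cond_layers slice up exactly the same activations that a single fused cond_layer of shape (2*n_channels*n_layers, n_mel_channels, 1) produces. A minimal sketch of that relationship (not the repository's conversion code; plain 1x1 convolutions without weight normalization, and illustrative sizes):

import torch

# Illustrative sizes only; real values come from the WN_config of the checkpoint.
n_mel_channels, n_channels, n_layers = 80, 256, 8

# Fused layout (the code removed from glow.py above): one 1x1 conv whose output
# stacks 2*n_channels conditioning channels per WaveNet layer.
fused = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1)

# Per-layer layout (the code added above): n_layers separate 1x1 convs, built
# here by slicing the fused weight/bias along the output-channel dimension.
cond_layers = torch.nn.ModuleList()
for i in range(n_layers):
    conv = torch.nn.Conv1d(n_mel_channels, 2 * n_channels, 1)
    lo, hi = i * 2 * n_channels, (i + 1) * 2 * n_channels
    conv.weight.data.copy_(fused.weight.data[lo:hi])
    conv.bias.data.copy_(fused.bias.data[lo:hi])
    cond_layers.append(conv)

# Sanity check: slicing the fused output equals applying the i-th per-layer conv.
spect = torch.randn(1, n_mel_channels, 63)
full = fused(spect)
for i, conv in enumerate(cond_layers):
    lo, hi = i * 2 * n_channels, (i + 1) * 2 * n_channels
    assert torch.allclose(full[:, lo:hi, :], conv(spect), atol=1e-5)

The same slicing logic would apply in reverse when converting a per-layer checkpoint to the fused layout; with weight-normalized layers the conversion is easier to do after removing weight norm.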

(second changed file; filename not shown)

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # *****************************************************************************
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
@@ -13,18 +12,19 @@
 # names of its contributors may be used to endorse or promote products
 # derived from this software without specific prior written permission.
 #
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-# ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
-# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
 # *****************************************************************************
+import copy
 import torch
 from torch.autograd import Variable
 import torch.nn.functional as F
@@ -34,8 +34,8 @@ import torch.nn.functional as F
 def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     n_channels_int = n_channels[0]
     in_act = input_a+input_b
-    t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
-    s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :])
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
     acts = t_act * s_act
     return acts
@@ -55,11 +55,7 @@ class WaveGlowLoss(torch.nn.Module):
             log_s_total = log_s_total + torch.sum(log_s)
             log_det_W_total += log_det_W_list[i]
-        loss = (
-            torch.sum(z * z) / (2 * self.sigma * self.sigma)
-            - log_s_total
-            - log_det_W_total
-        )
+        loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
         return loss/(z.size(0)*z.size(1)*z.size(2))
@@ -69,12 +65,10 @@ class Invertible1x1Conv(torch.nn.Module):
     of its weight matrix. If reverse=True it does convolution with
     inverse
     """
     def __init__(self, c):
         super(Invertible1x1Conv, self).__init__()
-        self.conv = torch.nn.Conv1d(
-            c, c, kernel_size=1, stride=1, padding=0, bias=False
-        )
+        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
+                                    bias=False)
         # Sample a random orthonormal matrix to initialize weights
         W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
@@ -92,11 +86,11 @@ class Invertible1x1Conv(torch.nn.Module):
         W = self.conv.weight.squeeze()
         if reverse:
-            if not hasattr(self, "W_inverse"):
+            if not hasattr(self, 'W_inverse'):
                 # Reverse computation
-                W_inverse = W.inverse()
+                W_inverse = W.float().inverse()
                 W_inverse = Variable(W_inverse[..., None])
-                if z.type() == "torch.cuda.HalfTensor":
+                if z.type() == 'torch.HalfTensor':
                     W_inverse = W_inverse.half()
                 self.W_inverse = W_inverse
             z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
@@ -110,25 +104,22 @@ class Invertible1x1Conv(torch.nn.Module):
 class WN(torch.nn.Module):
     """
-    This is the WaveNet like layer for the affine coupling. The primary
-    difference from WaveNet is the convolutions need not be causal. There is
-    also no dilation size reset. The dilation only doubles on each layer
+    This is the WaveNet like layer for the affine coupling. The primary difference
+    from WaveNet is the convolutions need not be causal. There is also no dilation
+    size reset. The dilation only doubles on each layer
     """
-    def __init__(
-        self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size
-    ):
+    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
+                 kernel_size):
         super(WN, self).__init__()
-        assert kernel_size % 2 == 1
-        assert n_channels % 2 == 0
+        assert(kernel_size % 2 == 1)
+        assert(n_channels % 2 == 0)
         self.n_layers = n_layers
         self.n_channels = n_channels
         self.in_layers = torch.nn.ModuleList()
         self.res_skip_layers = torch.nn.ModuleList()
-        self.cond_layers = torch.nn.ModuleList()
         start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
-        start = torch.nn.utils.weight_norm(start, name="weight")
+        start = torch.nn.utils.weight_norm(start, name='weight')
         self.start = start
         # Initializing last layer to 0 makes the affine coupling layers
@@ -138,22 +129,17 @@ class WN(torch.nn.Module):
         end.bias.data.zero_()
         self.end = end
+        cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
+        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
         for i in range(n_layers):
             dilation = 2 ** i
             padding = int((kernel_size*dilation - dilation)/2)
-            in_layer = torch.nn.Conv1d(
-                n_channels,
-                2 * n_channels,
-                kernel_size,
-                dilation=dilation,
-                padding=padding,
-            )
-            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+            in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
+                                       dilation=dilation, padding=padding)
+            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
             self.in_layers.append(in_layer)
-            cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels, 1)
-            cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
-            self.cond_layers.append(cond_layer)
             # last one is not necessary
             if i < n_layers - 1:
@@ -161,51 +147,43 @@ class WN(torch.nn.Module):
             else:
                 res_skip_channels = n_channels
             res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(
-                res_skip_layer, name="weight"
-            )
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
             self.res_skip_layers.append(res_skip_layer)
     def forward(self, forward_input):
         audio, spect = forward_input
         audio = self.start(audio)
+        output = torch.zeros_like(audio)
+        n_channels_tensor = torch.IntTensor([self.n_channels])
+        spect = self.cond_layer(spect)
         for i in range(self.n_layers):
+            spect_offset = i*2*self.n_channels
             acts = fused_add_tanh_sigmoid_multiply(
                 self.in_layers[i](audio),
-                self.cond_layers[i](spect),
-                torch.IntTensor([self.n_channels]),
-            )
+                spect[:,spect_offset:spect_offset+2*self.n_channels,:],
+                n_channels_tensor)
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.n_layers - 1:
-                audio = res_skip_acts[:, : self.n_channels, :] + audio
-                skip_acts = res_skip_acts[:, self.n_channels :, :]
+                audio = audio + res_skip_acts[:,:self.n_channels,:]
+                output = output + res_skip_acts[:,self.n_channels:,:]
             else:
-                skip_acts = res_skip_acts
-            if i == 0:
-                output = skip_acts
-            else:
-                output = skip_acts + output
+                output = output + res_skip_acts
         return self.end(output)
 class WaveGlow(torch.nn.Module):
-    def __init__(
-        self,
-        n_mel_channels,
-        n_flows,
-        n_group,
-        n_early_every,
-        n_early_size,
-        WN_config,
-    ):
+    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
+                 n_early_size, WN_config):
         super(WaveGlow, self).__init__()
-        self.upsample = torch.nn.ConvTranspose1d(
-            n_mel_channels, n_mel_channels, 1024, stride=256
-        )
-        assert n_group % 2 == 0
+        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
+                                                 n_mel_channels,
+                                                 1024, stride=256)
+        assert(n_group % 2 == 0)
         self.n_flows = n_flows
         self.n_group = n_group
         self.n_early_every = n_early_every
@@ -224,8 +202,7 @@ class WaveGlow(torch.nn.Module):
         n_remaining_channels = n_remaining_channels - self.n_early_size
         self.convinv.append(Invertible1x1Conv(n_remaining_channels))
         self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config))
-        self.n_remaining_channels = n_remaining_channels
-        # Useful during inference
+        self.n_remaining_channels = n_remaining_channels  # Useful during inference
     def forward(self, forward_input):
         """
@@ -236,16 +213,12 @@ class WaveGlow(torch.nn.Module):
         # Upsample spectrogram to size of audio
         spect = self.upsample(spect)
-        assert spect.size(2) >= audio.size(1)
+        assert(spect.size(2) >= audio.size(1))
         if spect.size(2) > audio.size(1):
             spect = spect[:, :, :audio.size(1)]
         spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
-        spect = (
-            spect.contiguous()
-            .view(spect.size(0), spect.size(1), -1)
-            .permute(0, 2, 1)
-        )
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
         audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
         output_audio = []
@@ -282,21 +255,16 @@ class WaveGlow(torch.nn.Module):
         spect = spect[:, :, :-time_cutoff]
         spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
-        spect = (
-            spect.contiguous()
-            .view(spect.size(0), spect.size(1), -1)
-            .permute(0, 2, 1)
-        )
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
-        if spect.type() == "torch.cuda.HalfTensor":
-            audio = torch.cuda.HalfTensor(
-                spect.size(0), self.n_remaining_channels, spect.size(2)
-            ).normal_()
+        if spect.type() == 'torch.HalfTensor':
+            audio = torch.HalfTensor(spect.size(0),
+                                     self.n_remaining_channels,
+                                     spect.size(2)).normal_()
         else:
-            # cuda.FloatTensor -> FloatTensor
-            audio = torch.FloatTensor(
-                spect.size(0), self.n_remaining_channels, spect.size(2)
-            ).normal_()
+            audio = torch.FloatTensor(spect.size(0),
+                                      self.n_remaining_channels,
+                                      spect.size(2)).normal_()
         audio = torch.autograd.Variable(sigma*audio)
@@ -306,6 +274,7 @@ class WaveGlow(torch.nn.Module):
             audio_1 = audio[:,n_half:,:]
             output = self.WN[k]((audio_0, spect))
+
             s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b)/torch.exp(s)
@@ -314,20 +283,13 @@ class WaveGlow(torch.nn.Module):
             audio = self.convinv[k](audio, reverse=True)
             if k % self.n_early_every == 0 and k > 0:
-                if spect.type() == "torch.cuda.HalfTensor":
-                    z = torch.cuda.HalfTensor(
-                        spect.size(0), self.n_early_size, spect.size(2)
-                    ).normal_()
+                if spect.type() == 'torch.HalfTensor':
+                    z = torch.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                 else:
-                    # cuda.FloatTensor -> FloatTensor
-                    z = torch.FloatTensor(
-                        spect.size(0), self.n_early_size, spect.size(2)
-                    ).normal_()
+                    z = torch.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_()
                 audio = torch.cat((sigma*z, audio),1)
-        audio = (
-            audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
-        )
+        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
         return audio
     @staticmethod
@@ -336,7 +298,7 @@ class WaveGlow(torch.nn.Module):
         for WN in waveglow.WN:
             WN.start = torch.nn.utils.remove_weight_norm(WN.start)
             WN.in_layers = remove(WN.in_layers)
-            WN.cond_layers = remove(WN.cond_layers)
+            WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
             WN.res_skip_layers = remove(WN.res_skip_layers)
         return waveglow
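
After this commit the two files end up on different conditioning layouts, so a checkpoint saved by one variant will not load into the other without key remapping. A small, assumption-laden sketch for telling the layouts apart from a saved state_dict; the key prefixes follow from the attribute names in the diffs, the weight_g/weight_v parameters are what torch.nn.utils.weight_norm registers, and the helper name and commented-out usage below are hypothetical:

import torch

def conditioning_layout(state_dict):
    """Guess whether a WaveGlow state_dict uses per-layer or fused conditioning."""
    keys = list(state_dict.keys())
    if any(".cond_layers." in k for k in keys):
        return "per_layer"  # matches the glow.py version above
    if any(".cond_layer." in k for k in keys):
        return "fused"      # matches the second file's version
    return "unknown"

# Hypothetical usage; the checkpoint filename and the "model" field are
# assumptions about how a particular pretrained release is packaged.
# ckpt = torch.load("waveglow_checkpoint.pt", map_location="cpu")
# print(conditioning_layout(ckpt["model"].state_dict()))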

(third changed file; filename not shown)

@@ -39,9 +39,9 @@ class HParams(object):
     filter_length = 1024
     hop_length = 256
     win_length = 1024
-    n_mel_channels: int = 40
+    n_mel_channels: int = 80
     mel_fmin: float = 0.0
-    mel_fmax: float = 4000.0
+    mel_fmax: float = 8000.0
     ################################
     # Model Parameters             #
     ################################
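
The hyper-parameter change only pays off if the mel spectrograms fed to the vocoder are computed with the same front end the pretrained model saw: 80 mel channels with an 8 kHz upper band edge, together with the 1024/256/1024 filter/hop/window sizes already in HParams. A quick sketch of a matching mel filterbank; librosa is an assumption here (the repository may build its filterbank elsewhere), and 22050 Hz is the usual LJSpeech sampling rate rather than a value shown in this diff:

import librosa

# Parameters mirroring the HParams block above; sampling_rate is an assumption.
sampling_rate = 22050
filter_length = 1024  # n_fft
n_mel_channels = 80
mel_fmin, mel_fmax = 0.0, 8000.0

# Mel filterbank with the same band limits the pretrained vocoder expects.
mel_basis = librosa.filters.mel(
    sr=sampling_rate,
    n_fft=filter_length,
    n_mels=n_mel_channels,
    fmin=mel_fmin,
    fmax=mel_fmax,
)
print(mel_basis.shape)  # (80, 513) -> n_mels x (n_fft // 2 + 1)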