mirror of https://github.com/malarinv/tacotron2
model.py: rewrite
parent 1ec0e5e8cd
commit 4af4ccb135
model.py (77 changed lines)
@@ -1,3 +1,4 @@
+from math import sqrt
 import torch
 from torch.autograd import Variable
 from torch import nn
@@ -56,7 +57,7 @@ class Attention(nn.Module):
 
         processed_query = self.query_layer(query.unsqueeze(1))
         processed_attention_weights = self.location_layer(attention_weights_cat)
-        energies = self.v(F.tanh(
+        energies = self.v(torch.tanh(
             processed_query + processed_attention_weights + processed_memory))
 
         energies = energies.squeeze(-1)
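The F.tanh -> torch.tanh swap tracks the deprecation of nn.functional.tanh in newer PyTorch releases; the two compute the same values, so it is a drop-in replacement. A minimal check (arbitrary shape):

import torch

x = torch.randn(4, 8)                        # arbitrary test tensor
assert torch.equal(torch.tanh(x), x.tanh())  # torch.tanh is the non-deprecated form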
@@ -107,7 +108,6 @@ class Postnet(nn.Module):
 
     def __init__(self, hparams):
         super(Postnet, self).__init__()
-        self.dropout = nn.Dropout(0.5)
         self.convolutions = nn.ModuleList()
 
         self.convolutions.append(
@@ -141,9 +141,8 @@ class Postnet(nn.Module):
 
     def forward(self, x):
         for i in range(len(self.convolutions) - 1):
-            x = self.dropout(F.tanh(self.convolutions[i](x)))
-
-        x = self.dropout(self.convolutions[-1](x))
+            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
+        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
 
         return x
 
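Replacing the nn.Dropout member with F.dropout(x, 0.5, self.training) keeps the behaviour but makes the train/eval switch explicit: with training=False the call is a no-op, just like an nn.Dropout module after .eval(). A small sketch of that property:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
# Inference: functional dropout with training=False returns the input unchanged.
assert torch.equal(F.dropout(x, 0.5, training=False), x)
# Training: roughly half the activations are zeroed, the rest scaled by 1 / (1 - p).
y = F.dropout(x, 0.5, training=True)
print((y == 0.0).float().mean())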
@@ -155,7 +154,6 @@ class Encoder(nn.Module):
     """
    def __init__(self, hparams):
         super(Encoder, self).__init__()
-        self.dropout = nn.Dropout(0.5)
 
         convolutions = []
         for _ in range(hparams.encoder_n_convolutions):
@@ -175,7 +173,7 @@ class Encoder(nn.Module):
 
     def forward(self, x, input_lengths):
         for conv in self.convolutions:
-            x = self.dropout(F.relu(conv(x)))
+            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
 
         x = x.transpose(1, 2)
 
@@ -194,7 +192,7 @@ class Encoder(nn.Module):
 
     def inference(self, x):
         for conv in self.convolutions:
-            x = self.dropout(F.relu(conv(x)))
+            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
 
         x = x.transpose(1, 2)
 
@@ -215,13 +213,15 @@ class Decoder(nn.Module):
         self.prenet_dim = hparams.prenet_dim
         self.max_decoder_steps = hparams.max_decoder_steps
         self.gate_threshold = hparams.gate_threshold
+        self.p_attention_dropout = hparams.p_attention_dropout
+        self.p_decoder_dropout = hparams.p_decoder_dropout
 
         self.prenet = Prenet(
             hparams.n_mel_channels * hparams.n_frames_per_step,
             [hparams.prenet_dim, hparams.prenet_dim])
 
         self.attention_rnn = nn.LSTMCell(
-            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
+            hparams.prenet_dim + hparams.encoder_embedding_dim,
             hparams.attention_rnn_dim)
 
         self.attention_layer = Attention(
@@ -230,12 +230,12 @@ class Decoder(nn.Module):
             hparams.attention_location_kernel_size)
 
         self.decoder_rnn = nn.LSTMCell(
-            hparams.prenet_dim + hparams.encoder_embedding_dim,
+            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
             hparams.decoder_rnn_dim, 1)
 
         self.linear_projection = LinearNorm(
             hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
-            hparams.n_mel_channels*hparams.n_frames_per_step)
+            hparams.n_mel_channels * hparams.n_frames_per_step)
 
         self.gate_layer = LinearNorm(
             hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
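The corrected LSTMCell input sizes match what decode() now feeds them: the attention RNN consumes the prenet output concatenated with the attention context, and the decoder RNN consumes the attention hidden state concatenated with the same context. A shape check with hypothetical dimensions standing in for the hparams values:

import torch
from torch import nn

prenet_dim, encoder_embedding_dim = 256, 512       # hypothetical
attention_rnn_dim, decoder_rnn_dim = 1024, 1024    # hypothetical
batch = 2

attention_rnn = nn.LSTMCell(prenet_dim + encoder_embedding_dim, attention_rnn_dim)
decoder_rnn = nn.LSTMCell(attention_rnn_dim + encoder_embedding_dim, decoder_rnn_dim)

cell_input = torch.randn(batch, prenet_dim + encoder_embedding_dim)
attn_state = (torch.zeros(batch, attention_rnn_dim), torch.zeros(batch, attention_rnn_dim))
attention_hidden, attention_cell = attention_rnn(cell_input, attn_state)

decoder_input = torch.cat((attention_hidden, torch.randn(batch, encoder_embedding_dim)), -1)
dec_state = (torch.zeros(batch, decoder_rnn_dim), torch.zeros(batch, decoder_rnn_dim))
decoder_hidden, decoder_cell = decoder_rnn(decoder_input, dec_state)
print(attention_hidden.shape, decoder_hidden.shape)  # (2, 1024) (2, 1024)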
@@ -350,10 +350,13 @@ class Decoder(nn.Module):
         gate_output: gate output energies
         attention_weights:
         """
-        cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
-
+        cell_input = torch.cat((decoder_input, self.attention_context), -1)
         self.attention_hidden, self.attention_cell = self.attention_rnn(
             cell_input, (self.attention_hidden, self.attention_cell))
+        self.attention_hidden = F.dropout(
+            self.attention_hidden, self.p_attention_dropout, self.training)
+        self.attention_cell = F.dropout(
+            self.attention_cell, self.p_attention_dropout, self.training)
 
         attention_weights_cat = torch.cat(
             (self.attention_weights.unsqueeze(1),
@@ -363,10 +366,14 @@ class Decoder(nn.Module):
             attention_weights_cat, self.mask)
 
         self.attention_weights_cum += self.attention_weights
-        prenet_output = self.prenet(decoder_input)
-        decoder_input = torch.cat((prenet_output, self.attention_context), -1)
+        decoder_input = torch.cat(
+            (self.attention_hidden, self.attention_context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             decoder_input, (self.decoder_hidden, self.decoder_cell))
+        self.decoder_hidden = F.dropout(
+            self.decoder_hidden, self.p_decoder_dropout, self.training)
+        self.decoder_cell = F.dropout(
+            self.decoder_cell, self.p_decoder_dropout, self.training)
 
         decoder_hidden_attention_context = torch.cat(
             (self.decoder_hidden, self.attention_context), dim=1)
@@ -391,22 +398,23 @@ class Decoder(nn.Module):
         alignments: sequence of attention weights from the decoder
         """
 
-        decoder_input = self.get_go_frame(memory)
+        decoder_input = self.get_go_frame(memory).unsqueeze(0)
         decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
+        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
+        decoder_inputs = self.prenet(decoder_inputs)
 
         self.initialize_decoder_states(
             memory, mask=~get_mask_from_lengths(memory_lengths))
 
         mel_outputs, gate_outputs, alignments = [], [], []
-        while len(mel_outputs) < decoder_inputs.size(0):
+        while len(mel_outputs) < decoder_inputs.size(0) - 1:
+            decoder_input = decoder_inputs[len(mel_outputs)]
             mel_output, gate_output, attention_weights = self.decode(
                 decoder_input)
-            mel_outputs += [mel_output]
-            gate_outputs += [gate_output.squeeze(1)]
+            mel_outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output.squeeze()]
             alignments += [attention_weights]
 
-            decoder_input = decoder_inputs[len(mel_outputs) - 1]
-
         mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
             mel_outputs, gate_outputs, alignments)
 
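In the rewritten forward() the go frame is prepended and the prenet is applied once to the whole teacher-forced sequence, so each loop iteration just indexes the pre-computed inputs; the size(0) - 1 bound accounts for the extra go frame. A sketch of the shapes with hypothetical sizes:

import torch

batch, n_frames, n_mel = 2, 5, 80                       # hypothetical sizes
go_frame = torch.zeros(batch, n_mel).unsqueeze(0)       # (1, batch, n_mel), all zeros
targets = torch.randn(n_frames, batch, n_mel)           # time-major teacher-forcing targets
decoder_inputs = torch.cat((go_frame, targets), dim=0)  # (n_frames + 1, batch, n_mel)

steps = decoder_inputs.size(0) - 1                      # loop bound used above
assert steps == n_frames                                # one decode step per target frame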
@@ -430,13 +438,14 @@ class Decoder(nn.Module):
 
         mel_outputs, gate_outputs, alignments = [], [], []
         while True:
+            decoder_input = self.prenet(decoder_input)
             mel_output, gate_output, alignment = self.decode(decoder_input)
 
-            mel_outputs += [mel_output]
-            gate_outputs += [gate_output.squeeze(1)]
+            mel_outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output]
             alignments += [alignment]
 
-            if F.sigmoid(gate_output.data) > self.gate_threshold:
+            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                 break
             elif len(mel_outputs) == self.max_decoder_steps:
                 print("Warning! Reached max decoder steps")
@@ -459,8 +468,9 @@ class Tacotron2(nn.Module):
         self.n_frames_per_step = hparams.n_frames_per_step
         self.embedding = nn.Embedding(
             hparams.n_symbols, hparams.symbols_embedding_dim)
-        torch.nn.init.xavier_uniform_(self.embedding.weight.data)
-
+        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
+        val = sqrt(3.0) * std  # uniform bounds for std
+        self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(hparams)
         self.decoder = Decoder(hparams)
         self.postnet = Postnet(hparams)
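The manual embedding init reproduces a Xavier-uniform bound with gain 1: std = sqrt(2 / (fan_in + fan_out)), and a uniform range of ±sqrt(3) * std equals the ±sqrt(6 / (fan_in + fan_out)) bound used by xavier_uniform_. A quick check with hypothetical sizes:

from math import sqrt

n_symbols, symbols_embedding_dim = 148, 512            # hypothetical hparams values
std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
val = sqrt(3.0) * std                                  # half-width of the uniform range
assert abs(val - sqrt(6.0 / (n_symbols + symbols_embedding_dim))) < 1e-12
print(std, val)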
@@ -469,8 +479,8 @@ class Tacotron2(nn.Module):
         text_padded, input_lengths, mel_padded, gate_padded, \
             output_lengths = batch
         text_padded = to_gpu(text_padded).long()
-        max_len = int(torch.max(input_lengths.data).numpy())
         input_lengths = to_gpu(input_lengths).long()
+        max_len = torch.max(input_lengths.data).item()
         mel_padded = to_gpu(mel_padded).float()
         gate_padded = to_gpu(gate_padded).float()
         output_lengths = to_gpu(output_lengths).long()
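Computing max_len with .item() after the lengths have gone through to_gpu avoids the old int(... .numpy()) round-trip, which only works for CPU tensors; .item() returns a plain Python number either way. For example:

import torch

input_lengths = torch.tensor([37, 51, 42])
max_len = torch.max(input_lengths.data).item()
assert isinstance(max_len, int) and max_len == 51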
@@ -485,7 +495,7 @@ class Tacotron2(nn.Module):
 
     def parse_output(self, outputs, output_lengths=None):
         if self.mask_padding and output_lengths is not None:
-            mask = ~get_mask_from_lengths(output_lengths+1) # +1 <stop> token
+            mask = ~get_mask_from_lengths(output_lengths)
             mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
             mask = mask.permute(1, 0, 2)
 
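Dropping the +1 means the padding mask now starts right at output_lengths instead of leaving one extra frame for the stop token. A sketch of the expand/permute step, assuming get_mask_from_lengths() returns a (batch, max_time) mask that is True on valid frames (as in the companion utils.py):

import torch

output_lengths = torch.tensor([3, 5])
max_time = int(output_lengths.max())
valid = torch.arange(max_time).unsqueeze(0) < output_lengths.unsqueeze(1)  # stand-in helper
mask = ~valid                                                # True on padded frames
n_mel_channels = 80                                          # hypothetical
mask = mask.expand(n_mel_channels, mask.size(0), mask.size(1))
mask = mask.permute(1, 0, 2)
print(mask.shape)  # torch.Size([2, 80, 5]): one mask copy per mel channel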
@@ -494,7 +504,6 @@ class Tacotron2(nn.Module):
             outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
 
         outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
-
         return outputs
 
     def forward(self, inputs):
@@ -512,14 +521,6 @@ class Tacotron2(nn.Module):
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
 
-        # DataParallel expects equal sized inputs/outputs, hence padding
-        if input_lengths is not None:
-            alignments = alignments.unsqueeze(0)
-            alignments = nn.functional.pad(
-                alignments,
-                (0, max_len - alignments.size(3), 0, 0),
-                "constant", 0)
-            alignments = alignments.squeeze()
         return self.parse_output(
             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
             output_lengths)