mirror of https://github.com/malarinv/tacotron2
Merge pull request #23 from NVIDIA/attention_full_mel
model.py: attending to full mel instead of prenet and dropout mellatest_model
commit
064629c9bc
11
model.py
11
model.py
|
|
@ -221,7 +221,7 @@ class Decoder(nn.Module):
|
||||||
[hparams.prenet_dim, hparams.prenet_dim])
|
[hparams.prenet_dim, hparams.prenet_dim])
|
||||||
|
|
||||||
self.attention_rnn = nn.LSTMCell(
|
self.attention_rnn = nn.LSTMCell(
|
||||||
hparams.prenet_dim + hparams.encoder_embedding_dim,
|
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
|
||||||
hparams.attention_rnn_dim)
|
hparams.attention_rnn_dim)
|
||||||
|
|
||||||
self.attention_layer = Attention(
|
self.attention_layer = Attention(
|
||||||
|
|
@ -230,7 +230,7 @@ class Decoder(nn.Module):
|
||||||
hparams.attention_location_kernel_size)
|
hparams.attention_location_kernel_size)
|
||||||
|
|
||||||
self.decoder_rnn = nn.LSTMCell(
|
self.decoder_rnn = nn.LSTMCell(
|
||||||
hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
|
hparams.prenet_dim + hparams.encoder_embedding_dim,
|
||||||
hparams.decoder_rnn_dim, 1)
|
hparams.decoder_rnn_dim, 1)
|
||||||
|
|
||||||
self.linear_projection = LinearNorm(
|
self.linear_projection = LinearNorm(
|
||||||
|
|
@ -351,8 +351,7 @@ class Decoder(nn.Module):
|
||||||
attention_weights:
|
attention_weights:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
decoder_input = self.prenet(decoder_input)
|
cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
|
||||||
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
|
||||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
||||||
cell_input, (self.attention_hidden, self.attention_cell))
|
cell_input, (self.attention_hidden, self.attention_cell))
|
||||||
|
|
||||||
|
|
@ -364,8 +363,8 @@ class Decoder(nn.Module):
|
||||||
attention_weights_cat, self.mask)
|
attention_weights_cat, self.mask)
|
||||||
|
|
||||||
self.attention_weights_cum += self.attention_weights
|
self.attention_weights_cum += self.attention_weights
|
||||||
decoder_input = torch.cat(
|
prenet_output = self.prenet(decoder_input)
|
||||||
(self.attention_hidden, self.attention_context), -1)
|
decoder_input = torch.cat((prenet_output, self.attention_context), -1)
|
||||||
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
||||||
decoder_input, (self.decoder_hidden, self.decoder_cell))
|
decoder_input, (self.decoder_hidden, self.decoder_cell))
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue