1
0
mirror of https://github.com/malarinv/tacotron2 synced 2026-03-08 09:42:34 +00:00

6 Commits

Author SHA1 Message Date
Rafael Valle
d5b64729d1 model.py: moving for better readibility 2018-05-20 12:22:06 -07:00
Rafael Valle
977cb37cea model.py: attending to full mel instead of prenet and dropout mel 2018-05-18 06:59:09 -07:00
Rafael Valle
da30fd8709 Merge pull request #20 from NVIDIA/fp16_path
Fp16 patch, not path!
2018-05-15 09:55:19 -07:00
Rafael Valle
27b1767cb2 train.py: fixing typo 2018-05-15 09:53:33 -07:00
Rafael Valle
817cd403d4 Merge branch 'master' of https://github.com/NVIDIA/tacotron2 into load_mel_from_disk 2018-05-15 09:51:41 -07:00
Rafael Valle
bd42cb6ed7 Merge pull request #19 from NVIDIA/load_mel_from_disk
Load mel from disk
2018-05-15 08:54:24 -07:00
2 changed files with 6 additions and 7 deletions

View File

@@ -221,7 +221,7 @@ class Decoder(nn.Module):
[hparams.prenet_dim, hparams.prenet_dim])
self.attention_rnn = nn.LSTMCell(
hparams.prenet_dim + hparams.encoder_embedding_dim,
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
hparams.attention_rnn_dim)
self.attention_layer = Attention(
@@ -230,7 +230,7 @@ class Decoder(nn.Module):
hparams.attention_location_kernel_size)
self.decoder_rnn = nn.LSTMCell(
hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
hparams.prenet_dim + hparams.encoder_embedding_dim,
hparams.decoder_rnn_dim, 1)
self.linear_projection = LinearNorm(
@@ -351,8 +351,7 @@ class Decoder(nn.Module):
attention_weights:
"""
decoder_input = self.prenet(decoder_input)
cell_input = torch.cat((decoder_input, self.attention_context), -1)
cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
@@ -364,8 +363,8 @@ class Decoder(nn.Module):
attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
decoder_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
prenet_output = self.prenet(decoder_input)
decoder_input = torch.cat((prenet_output, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))

View File

@@ -279,7 +279,7 @@ if __name__ == '__main__':
torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
print("FP16 Run:", hparams.fp16_run)
print("Dynamic Loss Scaling", hparams.dynamic_loss_scaling)
print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
print("Distributed Run:", hparams.distributed_run)
print("cuDNN Enabled:", hparams.cudnn_enabled)
print("cuDNN Benchmark:", hparams.cudnn_benchmark)