mirror of https://github.com/malarinv/tacotron2
model.py: renaming variables, removing dropout from lstm cell state, removing conversions now handled by amp
parent 087c86755f
commit 1480f82908
model.py — 22 changed lines
@@ -5,7 +5,6 @@ from torch import nn
 from torch.nn import functional as F
 from layers import ConvNorm, LinearNorm
 from utils import to_gpu, get_mask_from_lengths
-from fp16_optimizer import fp32_to_fp16, fp16_to_fp32


 class LocationLayer(nn.Module):
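Note on the dropped import: fp32_to_fp16 and fp16_to_fp32 existed only for manual half-precision casts, which amp now performs inside the training loop. A minimal sketch of the amp-based replacement, assuming NVIDIA's apex; the toy model, data, and opt_level='O1' are illustrative, not part of this commit:

import torch
from torch import nn
from apex import amp  # https://github.com/NVIDIA/apex

model = nn.Linear(80, 80).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# amp patches model and optimizer so fp16/fp32 casts happen automatically
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

x = torch.randn(16, 80, device='cuda')
loss = model(x).pow(2).mean()
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # loss scaling guards against fp16 underflow
optimizer.step()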
@@ -355,8 +354,6 @@ class Decoder(nn.Module):
             cell_input, (self.attention_hidden, self.attention_cell))
         self.attention_hidden = F.dropout(
             self.attention_hidden, self.p_attention_dropout, self.training)
-        self.attention_cell = F.dropout(
-            self.attention_cell, self.p_attention_dropout, self.training)

         attention_weights_cat = torch.cat(
             (self.attention_weights.unsqueeze(1),
@@ -372,8 +369,6 @@ class Decoder(nn.Module):
             decoder_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(
             self.decoder_hidden, self.p_decoder_dropout, self.training)
-        self.decoder_cell = F.dropout(
-            self.decoder_cell, self.p_decoder_dropout, self.training)

         decoder_hidden_attention_context = torch.cat(
             (self.decoder_hidden, self.attention_context), dim=1)
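This hunk and the one above make the same fix: F.dropout stays on the LSTM hidden state but is no longer applied to the cell state, so the cell's long-term memory passes between decoder steps unperturbed. A self-contained sketch of the surviving pattern (sizes and p are illustrative):

import torch
from torch import nn
from torch.nn import functional as F

rnn = nn.LSTMCell(input_size=256, hidden_size=1024)
h = torch.zeros(4, 1024)  # hidden state
c = torch.zeros(4, 1024)  # cell state
x = torch.randn(4, 256)

h, c = rnn(x, (h, c))
h = F.dropout(h, p=0.1, training=True)
# pre-commit, c also went through F.dropout; zeroing entries of the cell
# state corrupts the memory carried to the next step, hence the removal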
@@ -489,10 +484,6 @@ class Tacotron2(nn.Module):
             (text_padded, input_lengths, mel_padded, max_len, output_lengths),
             (mel_padded, gate_padded))

-    def parse_input(self, inputs):
-        inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
-        return inputs
-
     def parse_output(self, outputs, output_lengths=None):
         if self.mask_padding and output_lengths is not None:
             mask = ~get_mask_from_lengths(output_lengths)
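parse_output and its masking survive unchanged here; only the fp16-cast helper parse_input goes away. For context, a device-agnostic sketch of what the get_mask_from_lengths helper imported from utils computes (the repo's version allocates the range tensor on the GPU):

import torch

def get_mask_from_lengths(lengths):
    # True at valid (non-padded) positions; sketch of the helper in utils.py
    max_len = int(torch.max(lengths).item())
    ids = torch.arange(max_len, device=lengths.device)
    return ids < lengths.unsqueeze(1)

lengths = torch.tensor([5, 3])
mask = ~get_mask_from_lengths(lengths)   # True where output is padding
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])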
@@ -503,20 +494,18 @@ class Tacotron2(nn.Module):
             outputs[1].data.masked_fill_(mask, 0.0)
             outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

-        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
         return outputs

     def forward(self, inputs):
-        inputs, input_lengths, targets, max_len, \
-            output_lengths = self.parse_input(inputs)
-        input_lengths, output_lengths = input_lengths.data, output_lengths.data
+        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
+        text_lengths, output_lengths = text_lengths.data, output_lengths.data

-        embedded_inputs = self.embedding(inputs).transpose(1, 2)
+        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

-        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
+        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

         mel_outputs, gate_outputs, alignments = self.decoder(
-            encoder_outputs, targets, memory_lengths=input_lengths)
+            encoder_outputs, mels, memory_lengths=text_lengths)

         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
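With parse_input gone, forward unpacks the batch tuple directly, and the renames (text_inputs, text_lengths, mels) say what each tensor is instead of the generic inputs/input_lengths/targets. The 1e3 fill in the context above is worth a note: it saturates the sigmoid over padded frames so the gate reads "stop" there. A quick self-contained check:

import torch

# gate energies: two live frames plus a padded frame filled with 1e3
gate_energies = torch.tensor([0.2, -1.5, 1e3])
print(torch.sigmoid(gate_energies))
# tensor([0.5498, 0.1824, 1.0000]) -> the padded frame is pinned to "stop",
# matching the gate target of 1 over padding in the repo's data loader,
# so padded frames contribute ~0 gate loss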
@@ -526,7 +515,6 @@ class Tacotron2(nn.Module):
             output_lengths)

     def inference(self, inputs):
-        inputs = self.parse_input(inputs)
         embedded_inputs = self.embedding(inputs).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
         mel_outputs, gate_outputs, alignments = self.decoder.inference(
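inference likewise now feeds the raw id tensor straight to the embedding, with no cast wrapper. A hypothetical call site, continuing from a constructed Tacotron2 model and assuming the repo's text_to_sequence helper and cleaner name (both from its text package, not this diff):

import torch
from text import text_to_sequence  # repo helper; exact signature assumed

model.eval()
ids = text_to_sequence("Hello, world.", ['english_cleaners'])
sequence = torch.LongTensor(ids).unsqueeze(0).cuda()
with torch.no_grad():
    mel, mel_postnet, gate, alignments = model.inference(sequence)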