mirror of https://github.com/malarinv/tacotron2
model.py: renaming variables, removing dropout from lstm cell state, removing conversions now handled by amp
parent 087c86755f
commit 1480f82908

model.py (22 changed lines)
@@ -5,7 +5,6 @@ from torch import nn
 from torch.nn import functional as F
 from layers import ConvNorm, LinearNorm
 from utils import to_gpu, get_mask_from_lengths
-from fp16_optimizer import fp32_to_fp16, fp16_to_fp32


 class LocationLayer(nn.Module):
@@ -355,8 +354,6 @@ class Decoder(nn.Module):
             cell_input, (self.attention_hidden, self.attention_cell))
         self.attention_hidden = F.dropout(
             self.attention_hidden, self.p_attention_dropout, self.training)
-        self.attention_cell = F.dropout(
-            self.attention_cell, self.p_attention_dropout, self.training)

         attention_weights_cat = torch.cat(
             (self.attention_weights.unsqueeze(1),
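The next hunk applies the same change to the decoder RNN. In both, dropout now regularizes only the hidden state returned by the LSTMCell, while the cell state is carried to the next step untouched. A minimal standalone sketch of the pattern that remains; batch size, layer sizes, and the dropout probability are placeholders, not the repository's hparams:

import torch
import torch.nn.functional as F
from torch import nn

# Placeholder dimensions; the real attention RNN sizes come from the model's hparams.
attention_rnn = nn.LSTMCell(768, 1024)
p_attention_dropout = 0.1

cell_input = torch.randn(4, 768)   # (batch, attention_rnn input size)
hidden = torch.zeros(4, 1024)      # previous hidden state h
cell = torch.zeros(4, 1024)        # previous cell state c

hidden, cell = attention_rnn(cell_input, (hidden, cell))
# Dropout is applied only to the exposed hidden state; the cell state
# (the LSTM's internal memory) is passed to the next step without dropout.
hidden = F.dropout(hidden, p=p_attention_dropout, training=True)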
@@ -372,8 +369,6 @@ class Decoder(nn.Module):
             decoder_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(
             self.decoder_hidden, self.p_decoder_dropout, self.training)
-        self.decoder_cell = F.dropout(
-            self.decoder_cell, self.p_decoder_dropout, self.training)

         decoder_hidden_attention_context = torch.cat(
             (self.decoder_hidden, self.attention_context), dim=1)
@@ -489,10 +484,6 @@ class Tacotron2(nn.Module):
             (text_padded, input_lengths, mel_padded, max_len, output_lengths),
             (mel_padded, gate_padded))

-    def parse_input(self, inputs):
-        inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
-        return inputs
-
     def parse_output(self, outputs, output_lengths=None):
         if self.mask_padding and output_lengths is not None:
             mask = ~get_mask_from_lengths(output_lengths)
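With parse_input removed, the model no longer casts its own inputs to half precision; the commit message defers that work to amp. As a hedged illustration only, assuming a torch.cuda.amp-style setup (the repository may instead rely on NVIDIA's apex amp, and model, optimizer, criterion, and batch are stand-ins for whatever the training script actually passes), a training step would look roughly like this:

import torch

scaler = torch.cuda.amp.GradScaler()  # dynamic loss scaling for fp16 gradients

def train_step(model, optimizer, criterion, batch):
    # `batch` goes straight to parse_batch; no manual fp32_to_fp16 cast needed.
    x, y = model.parse_batch(batch)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # ops run in fp16 or fp32 as appropriate
        y_pred = model(x)
        loss = criterion(y_pred, y)
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)            # unscales gradients, then steps the optimizer
    scaler.update()
    return loss.item()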
@@ -503,20 +494,18 @@ class Tacotron2(nn.Module):
             outputs[1].data.masked_fill_(mask, 0.0)
             outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

-        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
         return outputs

     def forward(self, inputs):
-        inputs, input_lengths, targets, max_len, \
-            output_lengths = self.parse_input(inputs)
-        input_lengths, output_lengths = input_lengths.data, output_lengths.data
+        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
+        text_lengths, output_lengths = text_lengths.data, output_lengths.data

-        embedded_inputs = self.embedding(inputs).transpose(1, 2)
+        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

-        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
+        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

         mel_outputs, gate_outputs, alignments = self.decoder(
-            encoder_outputs, targets, memory_lengths=input_lengths)
+            encoder_outputs, mels, memory_lengths=text_lengths)

         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
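The forward hunk above is mostly a rename (input_lengths becomes text_lengths, targets becomes mels) plus direct unpacking of the input tuple now that parse_input is gone. A sketch of how a caller's batch lines up with the new names; every shape, size, and the vocabulary bound here is an illustrative assumption, not taken from the repository's hparams:

import torch

batch_size, max_text_len, n_mels, max_mel_len = 4, 50, 80, 400  # illustrative only

# The five-element tuple that forward() now unpacks directly.
text_inputs = torch.randint(0, 148, (batch_size, max_text_len))             # padded symbol IDs
text_lengths = torch.full((batch_size,), max_text_len, dtype=torch.long)    # text lengths before padding
mels = torch.randn(batch_size, n_mels, max_mel_len)                         # padded target mel-spectrograms
max_len = max_text_len
output_lengths = torch.full((batch_size,), max_mel_len, dtype=torch.long)   # mel lengths before padding

x = (text_inputs, text_lengths, mels, max_len, output_lengths)
# outputs = model(x)   # model being a constructed Tacotron2 instance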
@@ -526,7 +515,6 @@ class Tacotron2(nn.Module):
             output_lengths)

     def inference(self, inputs):
-        inputs = self.parse_input(inputs)
         embedded_inputs = self.embedding(inputs).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
         mel_outputs, gate_outputs, alignments = self.decoder.inference(
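inference likewise drops its parse_input call, so half precision, where wanted, is requested at the call site rather than inside the model. A hedged sketch under the same torch.cuda.amp assumption as above; model and text_inputs are stand-ins for an instantiated Tacotron2 and a batch of padded symbol IDs:

import torch

# Autocast at the call site replaces the removed in-model fp16 cast.
with torch.no_grad(), torch.cuda.amp.autocast():
    outputs = model.inference(text_inputs)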