mirror of https://github.com/malarinv/tacotron2
model.py: mixed squeeze target. fixing
parent
424b2f5bf0
commit
dcd925f6c8
10
model.py
10
model.py
|
|
@ -402,9 +402,8 @@ class Decoder(nn.Module):
|
||||||
while len(mel_outputs) < decoder_inputs.size(0):
|
while len(mel_outputs) < decoder_inputs.size(0):
|
||||||
mel_output, gate_output, attention_weights = self.decode(
|
mel_output, gate_output, attention_weights = self.decode(
|
||||||
decoder_input)
|
decoder_input)
|
||||||
|
mel_outputs += [mel_output]
|
||||||
mel_outputs += [mel_output.squeeze(1)]
|
gate_outputs += [gate_output.squeeze(1)]
|
||||||
gate_outputs += [gate_output.squeeze()]
|
|
||||||
alignments += [attention_weights]
|
alignments += [attention_weights]
|
||||||
|
|
||||||
decoder_input = decoder_inputs[len(mel_outputs) - 1]
|
decoder_input = decoder_inputs[len(mel_outputs) - 1]
|
||||||
|
|
@ -431,12 +430,11 @@ class Decoder(nn.Module):
|
||||||
self.initialize_decoder_states(memory, mask=None)
|
self.initialize_decoder_states(memory, mask=None)
|
||||||
|
|
||||||
mel_outputs, gate_outputs, alignments = [], [], []
|
mel_outputs, gate_outputs, alignments = [], [], []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
mel_output, gate_output, alignment = self.decode(decoder_input)
|
mel_output, gate_output, alignment = self.decode(decoder_input)
|
||||||
|
|
||||||
mel_outputs += [mel_output.squeeze(1)]
|
mel_outputs += [mel_output]
|
||||||
gate_outputs += [gate_output.squeeze()]
|
gate_outputs += [gate_output.squeeze(1)]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
|
||||||
if F.sigmoid(gate_output.data) > self.gate_threshold:
|
if F.sigmoid(gate_output.data) > self.gate_threshold:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue