From f03692c73d235b759940c813e987a441795450bc Mon Sep 17 00:00:00 2001
From: rafaelvalle
Date: Mon, 26 Nov 2018 16:37:44 -0800
Subject: [PATCH] cleanup, new model and waveglow

---
 .gitmodules    |  3 ++
 data_utils.py  | 19 +++++++------
 distributed.py | 52 ++++++++++++++++++++++++++++++++++
 hparams.py     | 19 +++++++------
 layers.py      |  6 ++--
 model.py       | 77 +++++++++++++++++++++++++-------------------------
 stft.py        |  2 +-
 train.py       | 56 ++++++++++++++++--------------------
 utils.py       | 21 ++++++--------
 waveglow       |  1 +
 10 files changed, 153 insertions(+), 103 deletions(-)
 create mode 100644 .gitmodules
 create mode 160000 waveglow

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..3ec228e
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "waveglow"]
+	path = waveglow
+	url = https://github.com/NVIDIA/waveglow
diff --git a/data_utils.py b/data_utils.py
index 09f42ac..fdfd287 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -14,9 +14,8 @@ class TextMelLoader(torch.utils.data.Dataset):
         2) normalizes text and converts them to sequences of one-hot vectors
         3) computes mel-spectrograms from audio files.
     """
-    def __init__(self, audiopaths_and_text, hparams, shuffle=True):
-        self.audiopaths_and_text = load_filepaths_and_text(
-            audiopaths_and_text, hparams.sort_by_length)
+    def __init__(self, audiopaths_and_text, hparams):
+        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
         self.text_cleaners = hparams.text_cleaners
         self.max_wav_value = hparams.max_wav_value
         self.sampling_rate = hparams.sampling_rate
@@ -26,8 +25,7 @@ class TextMelLoader(torch.utils.data.Dataset):
             hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
             hparams.mel_fmax)
         random.seed(1234)
-        if shuffle:
-            random.shuffle(self.audiopaths_and_text)
+        random.shuffle(self.audiopaths_and_text)

     def get_mel_text_pair(self, audiopath_and_text):
         # separate filename and text
@@ -38,7 +36,10 @@ class TextMelLoader(torch.utils.data.Dataset):

     def get_mel(self, filename):
         if not self.load_mel_from_disk:
-            audio = load_wav_to_torch(filename, self.sampling_rate)
+            audio, sampling_rate = load_wav_to_torch(filename)
+            if sampling_rate != self.stft.sampling_rate:
+                raise ValueError("{} SR doesn't match target {} SR".format(
+                    sampling_rate, self.stft.sampling_rate))
             audio_norm = audio / self.max_wav_value
             audio_norm = audio_norm.unsqueeze(0)
             audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
@@ -87,9 +88,9 @@ class TextMelCollate():
             text = batch[ids_sorted_decreasing[i]][0]
             text_padded[i, :text.size(0)] = text

-        # Right zero-pad mel-spec with extra single zero vector to mark the end
+        # Right zero-pad mel-spec
         num_mels = batch[0][1].size(0)
-        max_target_len = max([x[1].size(1) for x in batch]) + 1
+        max_target_len = max([x[1].size(1) for x in batch])
         if max_target_len % self.n_frames_per_step != 0:
             max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
             assert max_target_len % self.n_frames_per_step == 0
@@ -103,7 +104,7 @@ class TextMelCollate():
         for i in range(len(ids_sorted_decreasing)):
             mel = batch[ids_sorted_decreasing[i]][1]
             mel_padded[i, :, :mel.size(1)] = mel
-            gate_padded[i, mel.size(1):] = 1
+            gate_padded[i, mel.size(1)-1:] = 1
             output_lengths[i] = mel.size(1)

         return text_padded, input_lengths, mel_padded, gate_padded, \
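The collate change above drops the extra all-zero end frame and instead marks the stop token on the last real mel frame. A minimal sketch of the new padding and gate-target behaviour, with toy tensors and illustrative shapes only:

# Toy illustration of the collate change: mels are padded to a multiple of
# n_frames_per_step and the gate target is 1 from the last real frame onward
# (mel.size(1)-1), not from one frame past the end.
import torch

n_frames_per_step = 1
mels = [torch.randn(80, 37), torch.randn(80, 50)]  # (n_mel_channels, T)

max_target_len = max(m.size(1) for m in mels)
if max_target_len % n_frames_per_step != 0:
    max_target_len += n_frames_per_step - max_target_len % n_frames_per_step

mel_padded = torch.zeros(len(mels), 80, max_target_len)
gate_padded = torch.zeros(len(mels), max_target_len)
for i, mel in enumerate(mels):
    mel_padded[i, :, :mel.size(1)] = mel
    gate_padded[i, mel.size(1) - 1:] = 1   # stop token active on the final frame

print(gate_padded[0, 35:40])  # -> tensor([0., 1., 1., 1., 1.])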
diff --git a/distributed.py b/distributed.py
index ebe3b5b..07e8a5b 100644
--- a/distributed.py
+++ b/distributed.py
@@ -118,3 +118,55 @@ class DistributedDataParallel(Module):
         super(DistributedDataParallel, self).train(mode)
         self.module.train(mode)
 '''
+'''
+Modifies existing model to do gradient allreduce, but doesn't change class
+so you don't need "module"
+'''
+def apply_gradient_allreduce(module):
+    if not hasattr(dist, '_backend'):
+        module.warn_on_half = True
+    else:
+        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+    for p in module.state_dict().values():
+        if not torch.is_tensor(p):
+            continue
+        dist.broadcast(p, 0)
+
+    def allreduce_params():
+        if(module.needs_reduction):
+            module.needs_reduction = False
+            buckets = {}
+            for param in module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = type(param.data)
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+            if module.warn_on_half:
+                if torch.cuda.HalfTensor in buckets:
+                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
+                          " It is recommended to use the NCCL backend in this case. This currently requires " +
+                          "PyTorch built from top of tree master.")
+                    module.warn_on_half = False
+
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                dist.all_reduce(coalesced)
+                coalesced /= dist.get_world_size()
+                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                    buf.copy_(synced)
+
+    for param in list(module.parameters()):
+        def allreduce_hook(*unused):
+            param._execution_engine.queue_callback(allreduce_params)
+        if param.requires_grad:
+            param.register_hook(allreduce_hook)
+
+    def set_needs_reduction(self, input, output):
+        self.needs_reduction = True
+
+    module.register_forward_hook(set_needs_reduction)
+    return module
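apply_gradient_allreduce averages gradients once per backward pass by flattening all same-dtype gradients into one buffer, reducing it, and copying the result back. The sketch below shows just that bucketing round trip; the collective call is stubbed out so it runs on a single process without an initialized process group:

# Sketch of the bucketing trick used by apply_gradient_allreduce: gradients of
# one dtype are flattened into a single contiguous buffer, reduced once, then
# copied back in place. dist.all_reduce is simulated by a no-op here.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 4.0)]   # stand-ins for param.grad.data
coalesced = _flatten_dense_tensors(grads)          # one 1-D buffer of length 7
# dist.all_reduce(coalesced) would sum across ranks here
coalesced /= 2.0                                   # divide by world size (assume 2 ranks)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

print(grads[0])  # -> all 0.5: averaged gradients written back in place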
diff --git a/hparams.py b/hparams.py
index a3203e2..0f5e90c 100644
--- a/hparams.py
+++ b/hparams.py
@@ -10,7 +10,7 @@ def create_hparams(hparams_string=None, verbose=False):
         # Experiment Parameters        #
         ################################
         epochs=500,
-        iters_per_checkpoint=500,
+        iters_per_checkpoint=1000,
         seed=1234,
         dynamic_loss_scaling=True,
         fp16_run=False,
@@ -24,10 +24,9 @@ def create_hparams(hparams_string=None, verbose=False):
         # Data Parameters              #
         ################################
         load_mel_from_disk=False,
-        training_files='filelists/ljs_audio_text_train_filelist.txt',
-        validation_files='filelists/ljs_audio_text_val_filelist.txt',
+        training_files='filelists/ljs_audio22khz_text_train_filelist.txt',
+        validation_files='filelists/ljs_audio22khz_text_val_filelist.txt',
         text_cleaners=['english_cleaners'],
-        sort_by_length=False,

         ################################
         # Audio Parameters             #
@@ -39,7 +38,7 @@ def create_hparams(hparams_string=None, verbose=False):
         win_length=1024,
         n_mel_channels=80,
         mel_fmin=0.0,
-        mel_fmax=None,  # if None, half the sampling rate
+        mel_fmax=8000.0,

         ################################
         # Model Parameters             #
@@ -57,7 +56,9 @@ def create_hparams(hparams_string=None, verbose=False):
         decoder_rnn_dim=1024,
         prenet_dim=256,
         max_decoder_steps=1000,
-        gate_threshold=0.6,
+        gate_threshold=0.5,
+        p_attention_dropout=0.1,
+        p_decoder_dropout=0.1,

         # Attention parameters
         attention_rnn_dim=1024,
@@ -78,9 +79,9 @@ def create_hparams(hparams_string=None, verbose=False):
         use_saved_learning_rate=False,
         learning_rate=1e-3,
         weight_decay=1e-6,
-        grad_clip_thresh=1,
-        batch_size=48,
-        mask_padding=False  # set model's padded outputs to padded values
+        grad_clip_thresh=1.0,
+        batch_size=64,
+        mask_padding=True  # set model's padded outputs to padded values
     )

     if hparams_string:
diff --git a/layers.py b/layers.py
index f4935d5..615a64a 100644
--- a/layers.py
+++ b/layers.py
@@ -10,7 +10,7 @@ class LinearNorm(torch.nn.Module):
         super(LinearNorm, self).__init__()
         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

-        torch.nn.init.xavier_uniform(
+        torch.nn.init.xavier_uniform_(
             self.linear_layer.weight,
             gain=torch.nn.init.calculate_gain(w_init_gain))

@@ -31,7 +31,7 @@ class ConvNorm(torch.nn.Module):
                                     padding=padding, dilation=dilation,
                                     bias=bias)

-        torch.nn.init.xavier_uniform(
+        torch.nn.init.xavier_uniform_(
             self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

     def forward(self, signal):
@@ -42,7 +42,7 @@ class ConvNorm(torch.nn.Module):
 class TacotronSTFT(torch.nn.Module):
     def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                  n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
-                 mel_fmax=None):
+                 mel_fmax=8000.0):
         super(TacotronSTFT, self).__init__()
         self.n_mel_channels = n_mel_channels
         self.sampling_rate = sampling_rate
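layers.py now calls the in-place initializers (trailing underscore), which is the non-deprecated PyTorch API. A standalone check of the same call pattern, using a plain torch.nn.Linear for brevity rather than the repo's LinearNorm wrapper:

# xavier_uniform_ initializes the weight tensor in place; the non-underscore
# variant is deprecated. Same gain computation as LinearNorm/ConvNorm.
import torch

linear = torch.nn.Linear(512, 512, bias=False)
torch.nn.init.xavier_uniform_(
    linear.weight, gain=torch.nn.init.calculate_gain('tanh'))
print(linear.weight.std().item())  # roughly gain * sqrt(2 / (512 + 512))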
diff --git a/model.py b/model.py
index 263faa6..6673b7c 100644
--- a/model.py
+++ b/model.py
@@ -1,3 +1,4 @@
+from math import sqrt
 import torch
 from torch.autograd import Variable
 from torch import nn
@@ -56,7 +57,7 @@ class Attention(nn.Module):
         processed_query = self.query_layer(query.unsqueeze(1))
         processed_attention_weights = self.location_layer(attention_weights_cat)
-        energies = self.v(F.tanh(
+        energies = self.v(torch.tanh(
             processed_query + processed_attention_weights + processed_memory))

         energies = energies.squeeze(-1)
@@ -107,7 +108,6 @@ class Postnet(nn.Module):

     def __init__(self, hparams):
         super(Postnet, self).__init__()
-        self.dropout = nn.Dropout(0.5)
         self.convolutions = nn.ModuleList()

         self.convolutions.append(
@@ -141,9 +141,8 @@ class Postnet(nn.Module):

     def forward(self, x):
         for i in range(len(self.convolutions) - 1):
-            x = self.dropout(F.tanh(self.convolutions[i](x)))
-
-        x = self.dropout(self.convolutions[-1](x))
+            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
+        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

         return x

@@ -155,7 +154,6 @@ class Encoder(nn.Module):
     """
     def __init__(self, hparams):
         super(Encoder, self).__init__()
-        self.dropout = nn.Dropout(0.5)

         convolutions = []
         for _ in range(hparams.encoder_n_convolutions):
@@ -175,7 +173,7 @@ class Encoder(nn.Module):

     def forward(self, x, input_lengths):
         for conv in self.convolutions:
-            x = self.dropout(F.relu(conv(x)))
+            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

         x = x.transpose(1, 2)

@@ -194,7 +192,7 @@ class Encoder(nn.Module):

     def inference(self, x):
         for conv in self.convolutions:
-            x = self.dropout(F.relu(conv(x)))
+            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

         x = x.transpose(1, 2)

@@ -215,13 +213,15 @@ class Decoder(nn.Module):
         self.prenet_dim = hparams.prenet_dim
         self.max_decoder_steps = hparams.max_decoder_steps
         self.gate_threshold = hparams.gate_threshold
+        self.p_attention_dropout = hparams.p_attention_dropout
+        self.p_decoder_dropout = hparams.p_decoder_dropout

         self.prenet = Prenet(
             hparams.n_mel_channels * hparams.n_frames_per_step,
             [hparams.prenet_dim, hparams.prenet_dim])

         self.attention_rnn = nn.LSTMCell(
-            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
+            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

         self.attention_layer = Attention(
@@ -230,12 +230,12 @@ class Decoder(nn.Module):
             hparams.attention_location_kernel_size)

         self.decoder_rnn = nn.LSTMCell(
-            hparams.prenet_dim + hparams.encoder_embedding_dim,
+            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
             hparams.decoder_rnn_dim, 1)

         self.linear_projection = LinearNorm(
             hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
-            hparams.n_mel_channels*hparams.n_frames_per_step)
+            hparams.n_mel_channels * hparams.n_frames_per_step)

         self.gate_layer = LinearNorm(
             hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
@@ -350,10 +350,13 @@ class Decoder(nn.Module):
             gate_output: gate output energies
             attention_weights:
         """
-
-        cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
+        cell_input = torch.cat((decoder_input, self.attention_context), -1)
         self.attention_hidden, self.attention_cell = self.attention_rnn(
             cell_input, (self.attention_hidden, self.attention_cell))
+        self.attention_hidden = F.dropout(
+            self.attention_hidden, self.p_attention_dropout, self.training)
+        self.attention_cell = F.dropout(
+            self.attention_cell, self.p_attention_dropout, self.training)

         attention_weights_cat = torch.cat(
             (self.attention_weights.unsqueeze(1),
@@ -363,10 +366,14 @@ class Decoder(nn.Module):
             attention_weights_cat, self.mask)

         self.attention_weights_cum += self.attention_weights
-        prenet_output = self.prenet(decoder_input)
-        decoder_input = torch.cat((prenet_output, self.attention_context), -1)
+        decoder_input = torch.cat(
+            (self.attention_hidden, self.attention_context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             decoder_input, (self.decoder_hidden, self.decoder_cell))
+        self.decoder_hidden = F.dropout(
+            self.decoder_hidden, self.p_decoder_dropout, self.training)
+        self.decoder_cell = F.dropout(
+            self.decoder_cell, self.p_decoder_dropout, self.training)

         decoder_hidden_attention_context = torch.cat(
             (self.decoder_hidden, self.attention_context), dim=1)
@@ -391,22 +398,23 @@ class Decoder(nn.Module):
             alignments: sequence of attention weights from the decoder
         """
-        decoder_input = self.get_go_frame(memory)
+        decoder_input = self.get_go_frame(memory).unsqueeze(0)
         decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
+        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
+        decoder_inputs = self.prenet(decoder_inputs)
+
         self.initialize_decoder_states(
             memory, mask=~get_mask_from_lengths(memory_lengths))

         mel_outputs, gate_outputs, alignments = [], [], []
-
-        while len(mel_outputs) < decoder_inputs.size(0):
+        while len(mel_outputs) < decoder_inputs.size(0) - 1:
+            decoder_input = decoder_inputs[len(mel_outputs)]
             mel_output, gate_output, attention_weights = self.decode(
                 decoder_input)
-            mel_outputs += [mel_output]
-            gate_outputs += [gate_output.squeeze(1)]
+            mel_outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output.squeeze()]
             alignments += [attention_weights]

-            decoder_input = decoder_inputs[len(mel_outputs) - 1]
-
         mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
             mel_outputs, gate_outputs, alignments)

@@ -430,13 +438,14 @@ class Decoder(nn.Module):

         mel_outputs, gate_outputs, alignments = [], [], []
         while True:
+            decoder_input = self.prenet(decoder_input)
             mel_output, gate_output, alignment = self.decode(decoder_input)

-            mel_outputs += [mel_output]
-            gate_outputs += [gate_output.squeeze(1)]
+            mel_outputs += [mel_output.squeeze(1)]
+            gate_outputs += [gate_output]
             alignments += [alignment]

-            if F.sigmoid(gate_output.data) > self.gate_threshold:
+            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                 break
             elif len(mel_outputs) == self.max_decoder_steps:
                 print("Warning! Reached max decoder steps")
@@ -459,8 +468,9 @@ class Tacotron2(nn.Module):
         self.n_frames_per_step = hparams.n_frames_per_step
         self.embedding = nn.Embedding(
             hparams.n_symbols, hparams.symbols_embedding_dim)
-        torch.nn.init.xavier_uniform_(self.embedding.weight.data)
-
+        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
+        val = sqrt(3.0) * std  # uniform bounds for std
+        self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(hparams)
         self.decoder = Decoder(hparams)
         self.postnet = Postnet(hparams)
@@ -469,8 +479,8 @@ class Tacotron2(nn.Module):
         text_padded, input_lengths, mel_padded, gate_padded, \
             output_lengths = batch
         text_padded = to_gpu(text_padded).long()
-        max_len = int(torch.max(input_lengths.data).numpy())
         input_lengths = to_gpu(input_lengths).long()
+        max_len = torch.max(input_lengths.data).item()
         mel_padded = to_gpu(mel_padded).float()
         gate_padded = to_gpu(gate_padded).float()
         output_lengths = to_gpu(output_lengths).long()
@@ -485,7 +495,7 @@ class Tacotron2(nn.Module):

     def parse_output(self, outputs, output_lengths=None):
         if self.mask_padding and output_lengths is not None:
-            mask = ~get_mask_from_lengths(output_lengths+1)  # +1 token
+            mask = ~get_mask_from_lengths(output_lengths)
             mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
             mask = mask.permute(1, 0, 2)

@@ -494,7 +504,6 @@ class Tacotron2(nn.Module):
             outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

         outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
-
         return outputs

     def forward(self, inputs):
@@ -512,14 +521,6 @@ class Tacotron2(nn.Module):
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet

-        # DataParallel expects equal sized inputs/outputs, hence padding
-        if input_lengths is not None:
-            alignments = alignments.unsqueeze(0)
-            alignments = nn.functional.pad(
-                alignments,
-                (0, max_len - alignments.size(3), 0, 0),
-                "constant", 0)
-            alignments = alignments.squeeze()
         return self.parse_output(
             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
             output_lengths)
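The decoder rework moves the prenet out of decode(): the go frame is prepended, the prenet runs once over the whole shifted target sequence, and the teacher-forcing loop just indexes into it. A toy rendition of that control flow, with made-up dimensions and a stand-in prenet in place of the real modules:

# Toy version of the reworked teacher forcing: prenet applied once up front,
# loop consumes decoder_inputs[t] instead of re-running the prenet each step.
import torch

B, n_mel, T, prenet_dim = 2, 80, 5, 256
prenet = torch.nn.Sequential(torch.nn.Linear(n_mel, prenet_dim), torch.nn.ReLU())

targets = torch.randn(T, B, n_mel)                    # parse_decoder_inputs output: (T, B, n_mel)
go_frame = torch.zeros(1, B, n_mel)                   # get_go_frame equivalent
decoder_inputs = torch.cat((go_frame, targets), dim=0)
decoder_inputs = prenet(decoder_inputs)               # (T+1, B, prenet_dim), computed once

mel_outputs = []
while len(mel_outputs) < decoder_inputs.size(0) - 1:  # T steps; last frame is never fed back
    decoder_input = decoder_inputs[len(mel_outputs)]  # (B, prenet_dim)
    mel_outputs.append(decoder_input)                 # decode(decoder_input) would go here
print(len(mel_outputs), mel_outputs[0].shape)         # 5 torch.Size([2, 256])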
diff --git a/stft.py b/stft.py
index 8e137d3..03cd82f 100644
--- a/stft.py
+++ b/stft.py
@@ -61,7 +61,7 @@ class STFT(torch.nn.Module):
             np.linalg.pinv(scale * fourier_basis).T[:, None, :])

         if window is not None:
-            assert(win_length >= filter_length)
+            assert(filter_length >= win_length)
             # get window and zero center pad it to filter_length
             fft_window = get_window(window, win_length, fftbins=True)
             fft_window = pad_center(fft_window, filter_length)
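The corrected assertion reflects that the window is zero-padded up to the FFT size, so win_length must not exceed filter_length. A quick check using the same scipy/librosa calls as stft.py (positional pad_center matches the older librosa API the repo targets):

# The window must fit inside the FFT size, since pad_center pads the window
# up to filter_length, not the other way around.
from scipy.signal import get_window
from librosa.util import pad_center

filter_length, win_length = 1024, 1024
assert filter_length >= win_length
fft_window = get_window('hann', win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)   # older librosa: positional size arg
print(fft_window.shape)                              # (1024,)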
diff --git a/train.py b/train.py
index 4413549..7600ceb 100644
--- a/train.py
+++ b/train.py
@@ -5,9 +5,9 @@ import math
 from numpy import finfo

 import torch
-from distributed import DistributedDataParallel
+from distributed import apply_gradient_allreduce
+import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
-from torch.nn import DataParallel
 from torch.utils.data import DataLoader

 from fp16_optimizer import FP16_Optimizer
@@ -30,19 +30,20 @@ def batchnorm_to_float(module):

 def reduce_tensor(tensor, num_gpus):
     rt = tensor.clone()
-    torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM)
+    dist.all_reduce(rt, op=dist.reduce_op.SUM)
     rt /= num_gpus
     return rt


 def init_distributed(hparams, n_gpus, rank, group_name):
     assert torch.cuda.is_available(), "Distributed mode requires CUDA."
-    print("Initializing distributed")
+    print("Initializing Distributed")
+
     # Set cuda device so everything is done on the right GPU.
     torch.cuda.set_device(rank % torch.cuda.device_count())

     # Initialize distributed communication
-    torch.distributed.init_process_group(
+    dist.init_process_group(
         backend=hparams.dist_backend, init_method=hparams.dist_url,
         world_size=n_gpus, rank=rank, group_name=group_name)
@@ -131,22 +132,20 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
                                 pin_memory=False, collate_fn=collate_fn)

         val_loss = 0.0
-        if distributed_run or torch.cuda.device_count() > 1:
-            batch_parser = model.module.parse_batch
-        else:
-            batch_parser = model.parse_batch
-
         for i, batch in enumerate(val_loader):
-            x, y = batch_parser(batch)
+            x, y = model.parse_batch(batch)
             y_pred = model(x)
             loss = criterion(y_pred, y)
-            reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
-                if distributed_run else loss.data[0]
+            if distributed_run:
+                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
+            else:
+                reduced_val_loss = loss.item()
             val_loss += reduced_val_loss
         val_loss = val_loss / (i + 1)

     model.train()
-    return val_loss
+    print("Validation loss {}: {:9f}  ".format(iteration, val_loss))
+    logger.log_validation(val_loss, model, y, y_pred, iteration)


 def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
@@ -176,6 +175,9 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
         optimizer = FP16_Optimizer(
             optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)

+    if hparams.distributed_run:
+        model = apply_gradient_allreduce(model)
+
     criterion = Tacotron2Loss()

     logger = prepare_directories_and_logger(
@@ -194,15 +196,10 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                 checkpoint_path, model, optimizer)
             if hparams.use_saved_learning_rate:
                 learning_rate = _learning_rate
-
             iteration += 1  # next iteration is iteration + 1
             epoch_offset = max(0, int(iteration / len(train_loader)))

     model.train()
-    if hparams.distributed_run or torch.cuda.device_count() > 1:
-        batch_parser = model.module.parse_batch
-    else:
-        batch_parser = model.parse_batch
     # ================ MAIN TRAINING LOOP! ===================
     for epoch in range(epoch_offset, hparams.epochs):
         print("Epoch: {}".format(epoch))
@@ -212,18 +209,21 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                 param_group['lr'] = learning_rate

             model.zero_grad()
-            x, y = batch_parser(batch)
+            x, y = model.parse_batch(batch)
             y_pred = model(x)
+
             loss = criterion(y_pred, y)
-            reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
-                if hparams.distributed_run else loss.data[0]
+            if hparams.distributed_run:
+                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
+            else:
+                reduced_loss = loss.item()
             if hparams.fp16_run:
                 optimizer.backward(loss)
                 grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
             else:
                 loss.backward()
-                grad_norm = torch.nn.utils.clip_grad_norm(
+                grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), hparams.grad_clip_thresh)

             optimizer.step()
@@ -234,20 +234,14 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                 duration = time.perf_counter() - start
                 print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                     iteration, reduced_loss, grad_norm, duration))
-
                 logger.log_training(
                     reduced_loss, grad_norm, learning_rate, duration, iteration)

             if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
-                reduced_val_loss = validate(
-                    model, criterion, valset, iteration, hparams.batch_size,
-                    n_gpus, collate_fn, logger, hparams.distributed_run, rank)
+                validate(model, criterion, valset, iteration, hparams.batch_size,
+                         n_gpus, collate_fn, logger, hparams.distributed_run, rank)
                 if rank == 0:
-                    print("Validation loss {}: {:9f}  ".format(
-                        iteration, reduced_val_loss))
-                    logger.log_validation(
-                        reduced_val_loss, model, y, y_pred, iteration)
                     checkpoint_path = os.path.join(
                         output_directory, "checkpoint_{}".format(iteration))
                     save_checkpoint(model, optimizer, learning_rate, iteration,
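The training loop now uses .item() instead of the removed loss.data[0] indexing and the in-place clip_grad_norm_. A stripped-down, single-process version of one iteration for orientation, with the distributed, fp16, and logging branches omitted:

# Single-process sketch of the per-iteration bookkeeping in train():
# .item() replaces the deprecated loss.data[0], clip_grad_norm_ is the
# in-place successor of clip_grad_norm.
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

model.zero_grad()
loss = model(torch.randn(4, 10)).pow(2).mean()
reduced_loss = loss.item()                       # python float for logging
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
print(reduced_loss, float(grad_norm))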
diff --git a/utils.py b/utils.py
index 633ecff..c843d95 100644
--- a/utils.py
+++ b/utils.py
@@ -4,29 +4,26 @@ import torch


 def get_mask_from_lengths(lengths):
-    max_len = torch.max(lengths)
-    ids = torch.arange(0, max_len).long().cuda()
+    max_len = torch.max(lengths).item()
+    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
     mask = (ids < lengths.unsqueeze(1)).byte()
     return mask


-def load_wav_to_torch(full_path, sr):
+def load_wav_to_torch(full_path):
     sampling_rate, data = read(full_path)
-    assert sr == sampling_rate, "{} SR doesn't match {} on path {}".format(
-        sr, sampling_rate, full_path)
-    return torch.FloatTensor(data.astype(np.float32))
+    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


-def load_filepaths_and_text(filename, sort_by_length, split="|"):
+def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding='utf-8') as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
-
-    if sort_by_length:
-        filepaths_and_text.sort(key=lambda x: len(x[1]))
-
     return filepaths_and_text


 def to_gpu(x):
-    x = x.contiguous().cuda(async=True)
+    x = x.contiguous()
+
+    if torch.cuda.is_available():
+        x = x.cuda(non_blocking=True)
     return torch.autograd.Variable(x)
diff --git a/waveglow b/waveglow
new file mode 160000
index 0000000..4b1001f
--- /dev/null
+++ b/waveglow
@@ -0,0 +1 @@
+Subproject commit 4b1001fa3336a1184b8293745bb89b177457f09b
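get_mask_from_lengths builds a (batch, max_len) mask of valid positions; the new version allocates its index tensor directly on the GPU. A CPU-friendly rendition to show the output shape and values:

# Same logic as utils.get_mask_from_lengths, but with a CPU index tensor so
# it runs without CUDA; the repo version uses torch.cuda.LongTensor instead.
import torch

def get_mask_from_lengths_cpu(lengths):
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, dtype=torch.long)
    return (ids < lengths.unsqueeze(1)).byte()

print(get_mask_from_lengths_cpu(torch.LongTensor([2, 4])))
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1]], dtype=torch.uint8)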