1
0
mirror of https://github.com/malarinv/tacotron2 synced 2026-03-08 01:32:35 +00:00

mask utils update for 0.4 cuda

This commit is contained in:
Raul Puri
2018-05-04 10:14:30 -07:00
committed by GitHub
parent c141726a96
commit 6fbba8ef0f

View File

@@ -5,7 +5,7 @@ import torch
def get_mask_from_lengths(lengths):
max_len = torch.max(lengths)
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda()
ids = torch.arange(0, max_len).long().cuda()
mask = (ids < lengths.unsqueeze(1)).byte()
return mask