Merge pull request #4 from NVIDIA/mask-utils-0.4

mask utils update for 0.4 cuda
padding-patch-0.4
Rafael Valle 2018-05-04 11:02:30 -07:00 committed by GitHub
commit 2c545ac800
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -5,7 +5,7 @@ import torch
def get_mask_from_lengths(lengths): def get_mask_from_lengths(lengths):
max_len = torch.max(lengths) max_len = torch.max(lengths)
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda() ids = torch.arange(0, max_len).long().cuda()
mask = (ids < lengths.unsqueeze(1)).byte() mask = (ids < lengths.unsqueeze(1)).byte()
return mask return mask