mirror of https://github.com/malarinv/tacotron2
mask utils update for 0.4 cuda
parent
c141726a96
commit
6fbba8ef0f
2
utils.py
2
utils.py
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
|
|
||||||
def get_mask_from_lengths(lengths):
|
def get_mask_from_lengths(lengths):
|
||||||
max_len = torch.max(lengths)
|
max_len = torch.max(lengths)
|
||||||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda()
|
ids = torch.arange(0, max_len).long().cuda()
|
||||||
mask = (ids < lengths.unsqueeze(1)).byte()
|
mask = (ids < lengths.unsqueeze(1)).byte()
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue