mirror of
https://github.com/malarinv/tacotron2
synced 2026-03-08 01:32:35 +00:00
Merge pull request #4 from NVIDIA/mask-utils-0.4
mask utils update for 0.4 cuda
This commit is contained in:
2
utils.py
2
utils.py
@@ -5,7 +5,7 @@ import torch
|
|||||||
|
|
||||||
def get_mask_from_lengths(lengths):
|
def get_mask_from_lengths(lengths):
|
||||||
max_len = torch.max(lengths)
|
max_len = torch.max(lengths)
|
||||||
ids = torch.arange(0, max_len, out=torch.LongTensor(max_len)).cuda()
|
ids = torch.arange(0, max_len).long().cuda()
|
||||||
mask = (ids < lengths.unsqueeze(1)).byte()
|
mask = (ids < lengths.unsqueeze(1)).byte()
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user