# Page residue captured with this file ("Spaces: Runtime error"); not part of the program.
import torch
def create_grid_mask(seq_length, trunck_length, fill_triangle):
    """Build a square chunk ("trunk") mask of shape (seq_length, seq_length).

    Positions are grouped into consecutive chunks of ``trunck_length``
    (the last chunk may be shorter). Entry ``[i, j]`` is 1 when ``i`` and
    ``j`` fall in the same chunk. When ``fill_triangle`` is true, every
    entry on or below the main diagonal is set to 1 as well, so each row
    also covers all earlier chunks.

    Args:
        seq_length: Side length of the mask; must be positive.
        trunck_length: Number of positions per chunk; must be positive.
        fill_triangle: If true, start from a lower-triangular matrix of
            ones (diagonal included); otherwise start from all zeros.

    Returns:
        A ``torch.Tensor`` of zeros and ones in torch's default float dtype.

    Raises:
        AssertionError: If ``seq_length`` is not positive.
        ValueError: If ``trunck_length`` is not positive.
    """
    # Kept as an assert so callers that already see AssertionError here
    # are unaffected.
    assert seq_length > 0, "seq_length must be positive"
    if trunck_length <= 0:
        # A zero length used to surface as a bare ZeroDivisionError and a
        # negative one silently produced a meaningless mask.
        raise ValueError(
            f"trunck_length must be positive, got {trunck_length!r}"
        )

    # First build the base grid mask without taking seen_length into account.
    if fill_triangle:
        # Lower triangle and main diagonal are all 1.
        mask = torch.tril(torch.ones(seq_length, seq_length))
    else:
        mask = torch.zeros(seq_length, seq_length)

    # Fill one square block per chunk along the diagonal. Slices running
    # past seq_length are clamped, which handles a shorter final chunk.
    for chunk_start in range(0, seq_length, trunck_length):
        chunk_end = chunk_start + trunck_length
        mask[chunk_start:chunk_end, chunk_start:chunk_end] = 1
    return mask
if __name__ == "__main__":
    # Demo: 8 positions split into chunks of 3, with the lower triangle
    # filled, converted to integers for a compact printout.
    mask = create_grid_mask(seq_length=8, trunck_length=3, fill_triangle=True).int()
    print(mask)
    # Expected output:
    # tensor([[1, 1, 1, 0, 0, 0, 0, 0],
    #         [1, 1, 1, 0, 0, 0, 0, 0],
    #         [1, 1, 1, 0, 0, 0, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)