import torch def create_grid_mask(seq_length, trunck_length, fill_triangle): assert seq_length > 0 # 先不考虑seen_length创建一个grid mask: if fill_triangle: mask = 1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1) # 下三角与主对角线都为1 else: mask = torch.zeros(seq_length, seq_length) for i in range(seq_length): trunck_idx = i // trunck_length trunck_start = trunck_idx * trunck_length trunck_end = trunck_length + trunck_start mask[i][trunck_start:trunck_end] = 1 return mask if __name__ == "__main__": mask = create_grid_mask(seq_length=8, trunck_length=3, fill_triangle=True).int() print(mask) # tensor([[1, 1, 1, 0, 0, 0, 0, 0], # [1, 1, 1, 0, 0, 0, 0, 0], # [1, 1, 1, 0, 0, 0, 0, 0], # [1, 1, 1, 1, 1, 1, 0, 0], # [1, 1, 1, 1, 1, 1, 0, 0], # [1, 1, 1, 1, 1, 1, 0, 0], # [1, 1, 1, 1, 1, 1, 1, 1], # [1, 1, 1, 1, 1, 1, 1, 1]]