lll2343 committed
Commit 0bd3613 · verified · 1 Parent(s): 009f006

Upload attn_mask_utils.py with huggingface_hub

Files changed (1)
  1. attn_mask_utils.py +292 -0
attn_mask_utils.py ADDED
@@ -0,0 +1,292 @@
import torch
import copy


def find_prefix_seq_length_by_pe(
    pe: torch.Tensor
) -> torch.Tensor:
    """
    Find the sequence length at which the position encoding drops (indicating the prefix boundary).

    Args:
        pe: Position encoding tensor of shape [batch_size, seq_len].
            Contains position indices for each token in the sequence.

    Returns:
        torch.Tensor: A tensor of shape [batch_size] containing:
            - The index where the position encoding drops for each sequence.
            - -1 if no drop occurs in the sequence.
    """
    batch_size, _ = pe.shape
    prev = pe[:, :-1]
    curr = pe[:, 1:]
    drop_mask = curr < prev  # [batch_size, seq_len - 1]

    seq_len = torch.full((batch_size,), -1, dtype=torch.long)

    for b in range(batch_size):
        drop_pos = torch.nonzero(drop_mask[b], as_tuple=False)
        if drop_pos.numel() > 0:
            # Take the first drop position (+1 because we compared shifted sequences)
            seq_len[b] = drop_pos[0].item() + 1

    return seq_len
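
# Illustrative sketch (added for illustration, not part of the original upload): with
# made-up position ids that restart after a 4-token prefix, the drop is detected at
# index 4 for the first sequence and never for the second.
#
#   pe = torch.tensor([[0, 1, 2, 3, 0, 1],
#                      [0, 1, 2, 3, 4, 5]])
#   find_prefix_seq_length_by_pe(pe)   # -> tensor([ 4, -1])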


def update_causal_mask_with_pad_non_visible_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    text_mask_token_id: int = 151666,
    block_size: int = 4,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Updates a 2D attention mask for the whole sequence, based on input_ids and text_mask_token_id.

    Args:
        input_ids: Input token IDs of shape [seq_len], used to locate mask tokens.
        attn_mask_2d: 2D attention mask matrix of shape [seq_len, seq_len] where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
        text_mask_token_id: ID representing masked tokens
        block_size: Size of the diffusion window
        causal_attn: If True, maintains strict causal masking throughout

    Returns:
        Modified attention mask with updated visibility patterns
    """
    seq_len = input_ids.shape[0]
    device = input_ids.device

    # Identify masked tokens and the positions immediately preceding them
    input_mask = input_ids.eq(text_mask_token_id)
    input_before_mask = torch.zeros_like(input_mask)
    input_before_mask[:-1] = input_mask[1:]
    mask_cols = (input_mask | input_before_mask)
    non_mask = ~mask_cols

    rows = torch.arange(seq_len, device=device)[:, None]  # (seq_len, 1)
    cols = torch.arange(seq_len, device=device)           # (seq_len,)

    # For every position: index of the nearest non-mask token at or before it (cummax),
    # and of the nearest non-mask token at or after it (reversed cummin).
    indices = torch.arange(seq_len, device=device)
    prev_non_mask = (indices * non_mask).cummax(dim=0).values

    max_value = torch.iinfo(indices.dtype).max
    mask_indices = torch.where(non_mask, indices, torch.full_like(indices, max_value))
    reversed_mask_indices = torch.flip(mask_indices, dims=[0])
    reversed_cummin = reversed_mask_indices.cummin(dim=0).values
    next_non_mask = torch.flip(reversed_cummin, dims=[0])

    # ================= Part 1: Make positions after masks invisible =================
    infra_mask = (
        (cols > prev_non_mask) &
        (rows >= next_non_mask[None, :]) &
        mask_cols[None, :]
    )
    attn_mask_2d.masked_fill_(infra_mask, -float('inf'))

    # ================= Part 2: Allow visibility to previous positions (if not causal) =================
    if not causal_attn:
        visible_mask = (
            (rows > prev_non_mask[None, :]) &
            (rows < cols) &
            mask_cols[None, :]
        )
        attn_mask_2d.masked_fill_(visible_mask, 0.0)

    return attn_mask_2d
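
# Hedged usage sketch (assumed calling convention, not from the original repo): start
# from a standard additive causal mask and let this helper adjust visibility around
# the mask-token slots found in input_ids.
#
#   seq_len = input_ids.shape[0]
#   base = torch.triu(torch.full((seq_len, seq_len), -float('inf')), diagonal=1)
#   attn_mask_2d = update_causal_mask_with_pad_non_visible_2d(input_ids, base)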


def update_causal_mask_for_one_gen_window_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    block_size: int = 4,
    use_cache: bool = True,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Updates a 2D attention mask for a diffusion window in transformer inference.

    Args:
        input_ids: Input token IDs (unused in the current implementation)
        attn_mask_2d: 2D attention mask matrix of shape [seq_len, seq_len] where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
        block_size: Size of the diffusion window
        use_cache: Whether the key-value cache is being used
        causal_attn: If True, maintains strict causal masking throughout

    Returns:
        Modified attention mask with updated visibility patterns
    """
    if not causal_attn:
        # Make the diffusion window (last block_size tokens) fully visible to itself.
        # This allows bidirectional attention within the diffusion window.
        attn_mask_2d[-block_size:, -block_size:] = 0.0
        if use_cache:
            # Mask the last token from the previous round to prevent recomputation
            # and maintain generation consistency.
            attn_mask_2d[-block_size:, -block_size - 1] = -float('inf')

    return attn_mask_2d
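
# Hedged decoding-step sketch (shapes and names are assumptions): after appending a
# fresh window of block_size mask tokens to the sequence, re-open bidirectional
# attention inside that window before the next forward pass.
#
#   attn_mask_2d = update_causal_mask_for_one_gen_window_2d(
#       input_ids, attn_mask_2d, block_size=4, use_cache=True
#   )
#   # attn_mask_2d[-4:, -4:] is now all 0.0 (the window attends to itself both ways)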


def create_block_diff_mask_by_pe_1d(
    b: int,
    h: int,
    q_idx: torch.Tensor,
    kv_idx: torch.Tensor,
    block_size: int,
    x0_len_list: torch.Tensor,
    position_ids_list: torch.Tensor,
    causal_attn: bool = False,
) -> torch.Tensor:
    """Computes the attention mask for a single query-key position in Flex Attention.

    Args:
        b (int): Batch index (0 <= b < batch_size).
        h (int): Head index (unused in the current implementation, reserved for future multi-head support).
        q_idx (torch.Tensor): Query position index (scalar or 0D tensor).
        kv_idx (torch.Tensor): Key/Value position index (scalar or 0D tensor).
        block_size (int): Size of processing blocks for non-`x0` tokens.
        x0_len_list (torch.Tensor): Tensor of shape [batch_size] with `x0` segment lengths.
        position_ids_list (torch.Tensor): Tensor of shape [batch_size, seq_len] with position IDs.
        causal_attn (bool, optional): Enforces causal masking in mutual blocks if True. Defaults to False.

    Returns:
        torch.Tensor: Boolean indicating whether attention is allowed (True = allowed).
    """
    x0_len = x0_len_list[b]
    position_ids = position_ids_list[b]

    x0_flag_q = (q_idx < x0_len)
    x0_flag_kv = (kv_idx < x0_len)

    # Top-left: causal attention among x0 tokens
    block_causal = (
        x0_flag_q &
        x0_flag_kv &
        (q_idx >= kv_idx)
    )

    q_ith_block = (q_idx - x0_len) // block_size
    kv_ith_block = (kv_idx - x0_len) // block_size

    # Bottom-right: mutual (bidirectional) attention within the same block
    block_mutual = (
        (~x0_flag_q & ~x0_flag_kv) &
        (q_ith_block == kv_ith_block) &
        ((q_idx >= kv_idx) if causal_attn else True)
    )

    # Bottom-left: the prefix of x0 visible to q_idx's block
    prefix_len = position_ids[x0_len + q_ith_block * block_size]  # prefix length corresponding to q_idx's block
    block_prefix = (
        (~x0_flag_q & x0_flag_kv) &
        (kv_idx < prefix_len)
    )

    mask_val = (block_causal | block_mutual | block_prefix)
    return mask_val.to(torch.bool)
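
# Hedged sketch of wiring this mask_mod into PyTorch FlexAttention (torch >= 2.5);
# the partial binding and tensor names below are assumptions, not part of this module.
#
#   from functools import partial
#   from torch.nn.attention.flex_attention import create_block_mask
#
#   mask_mod = partial(
#       create_block_diff_mask_by_pe_1d,
#       block_size=4,
#       x0_len_list=x0_len_list,          # [batch_size]
#       position_ids_list=position_ids,   # [batch_size, seq_len]
#   )
#   block_mask = create_block_mask(mask_mod, B=batch_size, H=None, Q_LEN=seq_len, KV_LEN=seq_len)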


def create_block_diff_mask_by_pe_4d(
    block_size: int,
    x0_len_list: torch.Tensor,
    position_ids: torch.Tensor,
    causal_attn: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generates a 4D attention mask for block-diffusion attention patterns.

    The mask consists of three regions:
    1. Causal block (top-left): Standard causal attention for `x0` tokens.
    2. Mutual block (bottom-right): Non-causal attention within the same block for non-`x0` tokens.
    3. Prefix block (bottom-left): Non-`x0` tokens can attend to a prefix of the `x0` tokens.

    Args:
        block_size (int): Size of processing blocks for non-`x0` tokens.
        x0_len_list (torch.Tensor): Tensor of shape [batch_size] containing the length of the `x0` segment per batch.
        position_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing position IDs.
        causal_attn (bool, optional): If True, enforces causal masking in mutual blocks. Defaults to False.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - A float mask of shape [batch_size, 1, seq_len, seq_len] with `-inf` at masked (non-visible) positions.
            - A boolean mask of shape [batch_size, 1, seq_len, seq_len] indicating allowed attention positions.
    """
    batch_size, seq_len = position_ids.shape
    device = position_ids.device

    # Position indices, broadcastable to [batch_size, seq_len, seq_len]
    q_idx = torch.arange(seq_len, device=device).view(1, seq_len, 1)   # [1, seq_len, 1]
    kv_idx = torch.arange(seq_len, device=device).view(1, 1, seq_len)  # [1, 1, seq_len]

    # Broadcastable to [batch_size, seq_len, seq_len]
    x0_len = x0_len_list.view(batch_size, 1, 1)  # [batch_size, 1, 1]
    x0_flag_q = q_idx < x0_len
    x0_flag_kv = kv_idx < x0_len

    # Block indices, broadcastable to [batch_size, seq_len, seq_len]
    q_block_idx = (q_idx - x0_len) // block_size
    kv_block_idx = (kv_idx - x0_len) // block_size

    # Causal block (top-left)
    block_causal = x0_flag_q & x0_flag_kv & (q_idx >= kv_idx)

    # Mutual block (bottom-right)
    mutual_condition = (q_idx >= kv_idx) if causal_attn else torch.ones_like(q_idx, dtype=torch.bool)
    block_mutual = (~x0_flag_q & ~x0_flag_kv &
                    (q_block_idx == kv_block_idx) &
                    mutual_condition)

    # Prefix block (bottom-left)
    q_blk = torch.div(q_idx - x0_len, block_size, rounding_mode='floor')
    q_blk_start = (x0_len_list.view(batch_size, 1) + q_blk[:, :, 0] * block_size).clamp(min=0, max=seq_len - 1)  # [batch_size, seq_len]
    prefix_len = position_ids.gather(1, q_blk_start)  # position ID at the start of each query's block
    prefix_len = prefix_len.unsqueeze(2)
    block_prefix = (~x0_flag_q & x0_flag_kv) & (kv_idx < prefix_len)

    # FIXME Padding Mask
    # padding_mask = (position_ids.view(batch_size, 1, seq_len) != -1) & (position_ids.view(batch_size, seq_len, -1) != -1)

    # Combine masks
    final_mask = (block_causal | block_mutual | block_prefix)  # bool
    # & padding_mask
    customized_mask = torch.full_like(final_mask, float('-inf'), dtype=torch.bfloat16)
    customized_mask.masked_fill_(final_mask, 0.0)  # 0.0 where visible, -inf elsewhere

    # Add head dimension -> [batch_size, 1, seq_len, seq_len]
    return customized_mask.unsqueeze(1).to(device=device), final_mask.unsqueeze(1).to(device=device)
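
# Hedged usage note (assumed downstream use, not confirmed by this file): the float
# mask has the layout of a standard 4D additive attention mask, so it can be added to
# raw attention scores or passed wherever a [batch, 1, q_len, kv_len] bias is expected.
#
#   float_mask, bool_mask = create_block_diff_mask_by_pe_4d(
#       block_size=4, x0_len_list=x0_len_list, position_ids=position_ids
#   )
#   scores = scores + float_mask   # scores: [batch, heads, q_len, kv_len]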


def find_pred_pos_from_input_ids(
    input_ids: torch.LongTensor = None,
    text_mask_token_id: int = 151666,
) -> torch.Tensor:
    """Compute the relative prediction positions of masked tokens in a sequence.

    For non-masked positions the output is 0. For masked positions the value increments
    by 1 for each consecutive mask token, indicating how many steps ahead the prediction is.

    Args:
        input_ids (torch.LongTensor): Input token IDs of shape [batch_size, seq_len].
        text_mask_token_id (int, optional): Token ID representing masked positions. Defaults to 151666.

    Returns:
        torch.Tensor: A tensor of shape [batch_size, seq_len] where:
            - 0 indicates a non-masked token.
            - n > 0 indicates the nth consecutive masked token (e.g., 1 = first mask, 2 = second mask, etc.).
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    is_mask = (input_ids == text_mask_token_id)

    base_mask = torch.zeros((batch_size, seq_len), dtype=torch.int8, device=device)

    for b in range(batch_size):
        for ix in range(1, seq_len):
            if is_mask[b][ix]:
                # Increment the counter if the current token is masked
                base_mask[b][ix] = base_mask[b][ix - 1] + 1

    return base_mask
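

# A small runnable smoke check (added for illustration; not part of the original
# upload). The inputs are made up; the expected values follow directly from the
# function definitions above.
if __name__ == "__main__":
    MASK_ID = 151666  # module default text_mask_token_id
    ids = torch.tensor([[11, MASK_ID, MASK_ID, 22, MASK_ID]])
    assert find_pred_pos_from_input_ids(ids, text_mask_token_id=MASK_ID).tolist() == [[0, 1, 2, 0, 1]]

    position_ids = torch.arange(8).unsqueeze(0)  # [1, 8]
    float_mask, bool_mask = create_block_diff_mask_by_pe_4d(
        block_size=4, x0_len_list=torch.tensor([4]), position_ids=position_ids
    )
    assert float_mask.shape == (1, 1, 8, 8) and bool_mask.dtype == torch.bool
    print("attn_mask_utils smoke checks passed")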