fix bert_padding

- __pycache__/attention.cpython-311.pyc +0 -0
- __pycache__/layers.cpython-311.pyc +0 -0
- attention.py +13 -13
- layers.py +5 -5

__pycache__/attention.cpython-311.pyc
CHANGED
Binary files a/__pycache__/attention.cpython-311.pyc and b/__pycache__/attention.cpython-311.pyc differ

__pycache__/layers.cpython-311.pyc
CHANGED
Binary files a/__pycache__/layers.cpython-311.pyc and b/__pycache__/layers.cpython-311.pyc differ
attention.py
CHANGED

@@ -24,7 +24,7 @@ import sys
 import os
 # Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
-import
+from .bert_padding import pad_input, unpad_input_only, index_first_axis
 from .configuration_bert import FlexBertConfig, maybe_add_padding
 from .normalization import get_norm_layer
 from .initialization import ModuleType, init_weights

@@ -161,7 +161,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
 alibi_slopes=slopes,
 )
 else:
-qkv =
+qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
 unpad_bs, *_ = qkv.shape
 qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
 # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch

@@ -174,7 +174,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
 attention_probs = self.dropout(attention_probs)
 attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d

-attention =
+attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)

 return attention.view(bs, dim)

@@ -240,8 +240,8 @@ class BertAlibiUnpadAttention(nn.Module):
 self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
 if subset_idx is not None:
 return self.output(
-
-
+index_first_axis(self_output, subset_idx),
+index_first_axis(input_tensor, subset_idx),
 )
 else:
 return self.output(self_output, input_tensor)

@@ -415,7 +415,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
 unpad_bs, seqlen, _ = qkv.shape

 qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)

@@ -430,7 +430,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn))

@@ -565,7 +565,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
 unpad_bs, seqlen, _ = qkv.shape

 qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)

@@ -580,7 +580,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn.view(bs, dim)))

@@ -913,7 +913,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(
 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
 )  # batch, max_seqlen, thd
 unpad_bs, seqlen, *_ = qkv.shape

@@ -929,7 +929,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn))

@@ -1244,7 +1244,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
 )
 attn = attn.view(bs, dim)
 else:
-qkv =
+qkv = pad_input(
 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
 )  # batch, max_seqlen, thd
 unpad_bs, seqlen, *_ = qkv.shape

@@ -1260,7 +1260,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
 else None,
 )
 attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-attn =
+attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

 return self.out_drop(self.Wo(attn))
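For orientation, here is a minimal sketch of the padding round trip that the restored attention.py calls rely on. The *_sketch helpers below are hypothetical stand-ins whose behavior is inferred only from the call sites in the diff (pad_input scattering unpadded tokens back into a [batch, max_seqlen, ...] tensor, unpad_input_only keeping the tokens where the attention mask is 1, index_first_axis selecting a row subset); the real implementations live in the repo's bert_padding module and may differ.

# Minimal sketch, assuming the semantics implied by the call sites above.
import torch

def pad_input_sketch(unpadded, indices, batch, seqlen):
    # Scatter flattened non-pad tokens [total_nnz, ...] back into [batch, seqlen, ...].
    out = torch.zeros(batch * seqlen, *unpadded.shape[1:], dtype=unpadded.dtype)
    out[indices] = unpadded
    return out.view(batch, seqlen, *unpadded.shape[1:])

def unpad_input_only_sketch(padded, mask_bool):
    # Keep only tokens where the attention mask is True: [batch, seqlen, ...] -> [total_nnz, ...].
    return padded[mask_bool]

attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])            # [batch=2, seqlen=4]
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
hidden = torch.randn(indices.numel(), 8)                          # 5 unpadded tokens, dim=8

padded = pad_input_sketch(hidden, indices, 2, 4)                  # [2, 4, 8], zeros at pad positions
restored = unpad_input_only_sketch(padded, attn_mask == 1)        # [5, 8]
assert torch.equal(restored, hidden)

# index_first_axis(x, subset_idx) in the diff selects a row subset (e.g. masked positions);
# as a sketch, it behaves like plain first-axis indexing:
subset = hidden[torch.tensor([0, 3])]                             # [2, 8]
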
layers.py
CHANGED

@@ -20,7 +20,7 @@ from typing import Optional, Union, List
 import torch
 import torch.nn as nn

-import
+from .bert_padding import unpad_input, pad_input

 from .activation import get_act_fn
 from .attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer

@@ -155,7 +155,7 @@ class BertAlibiEncoder(nn.Module):
 # and ntokens_unpad is total number of non-padded tokens.
 # Then unpadding performs the following compression of the inputs:
 # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
-hidden_states, indices, cu_seqlens, _ =
+hidden_states, indices, cu_seqlens, _ = unpad_input(hidden_states, attention_mask_bool)

 # Add alibi matrix to extended_attention_mask
 if self._current_alibi_size < seqlen:

@@ -190,7 +190,7 @@ class BertAlibiEncoder(nn.Module):
 # and ntokens_unpad is total number of non-padded tokens.
 # Then padding performs the following de-compression:
 # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
-hidden_states =
+hidden_states = pad_input(hidden_states, indices, batch, seqlen)
 else:
 for i in range(len(self.layer) - 1):
 layer_module = self.layer[i]

@@ -636,7 +636,7 @@ class FlexBertUnpadEncoder(FlexBertEncoderBase):
 if indices is None and cu_seqlens is None and max_seqlen is None:
 attention_mask_bool = attention_mask.bool()
 batch, seqlen = hidden_states.shape[:2]
-hidden_states, indices, cu_seqlens, max_seqlen =
+hidden_states, indices, cu_seqlens, max_seqlen = unpad_input(
 hidden_states, attention_mask_bool
 )

@@ -649,7 +649,7 @@ class FlexBertUnpadEncoder(FlexBertEncoderBase):
 attn_mask=attention_mask,
 )

-return
+return pad_input(hidden_states, indices, batch, seqlen)
 else:
 for layer_module in self.layers:
 hidden_states = layer_module(
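Similarly, a minimal sketch of the compression and de-compression the encoder comments describe (hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden] and back). The unpad_input_sketch helper is a hypothetical stand-in whose four-value return (unpadded states, indices, cu_seqlens, max_seqlen) is inferred from the FlexBertUnpadEncoder call in the diff, not taken from bert_padding itself.

# Minimal sketch, assuming the return signature used in FlexBertUnpadEncoder above.
import torch
import torch.nn.functional as F

def unpad_input_sketch(hidden_states, attention_mask_bool):
    # [batch, seqlen, hidden] -> ([ntokens_unpad, hidden], indices, cu_seqlens, max_seqlen)
    batch, seqlen, hidden = hidden_states.shape
    seqlens = attention_mask_bool.sum(dim=-1, dtype=torch.int32)       # non-pad tokens per sequence
    indices = torch.nonzero(attention_mask_bool.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    unpadded = hidden_states.reshape(batch * seqlen, hidden)[indices]
    return unpadded, indices, cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool()               # [batch=2, seqlen=4]
x = torch.randn(2, 4, 16)
unpadded, indices, cu_seqlens, max_seqlen = unpad_input_sketch(x, mask)
print(unpadded.shape, cu_seqlens.tolist(), max_seqlen)                 # torch.Size([5, 16]) [0, 3, 5] 3

# The final `return pad_input(hidden_states, indices, batch, seqlen)` undoes the compression;
# padding positions come back as zeros, so only the non-pad tokens need to match:
restored = torch.zeros(2 * 4, 16).index_copy_(0, indices, unpadded).view(2, 4, 16)
assert torch.equal(restored[mask], x[mask])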