oweller2 committed
Commit e0229bb
Parent(s): 1f59624

added updated code

Files changed:
- __init__.py +3 -3
- attention.py +48 -38
- config.json +1 -1
- configuration_bert.py +3 -3
- modeling_flexbert.py +25 -41
__init__.py CHANGED

@@ -33,13 +33,14 @@ from .modeling_flexbert import (
 FlexBertForMaskedLM,
 FlexBertForSequenceClassification,
 FlexBertForMultipleChoice,
-
+FlexBertForCausalLM,
 )
 from .bert_padding import(
 IndexFirstAxis,
 IndexPutFirstAxis
 )
 
+
 __all__ = [
 "BertAlibiEmbeddings",
 "BertAlibiEncoder",
@@ -69,6 +70,5 @@ __all__ = [
 "FlexBertForMaskedLM",
 "FlexBertForSequenceClassification",
 "FlexBertForMultipleChoice",
-    "
-"IndexPutFirstAxis"
+"FlexBertForCausalLM"
 ]
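For orientation, the new class is now part of the package's public surface alongside the existing heads and padding helpers. A minimal import sketch, assuming the files above are vendored as a local package (the package name flex_bert here is hypothetical; the diff itself only uses relative imports):

    # Hypothetical package name; __init__.py itself only uses relative imports.
    from flex_bert import FlexBertForCausalLM, IndexFirstAxis, IndexPutFirstAxis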
attention.py CHANGED

@@ -74,7 +74,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
 self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -145,7 +145,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 alibi_slopes=slopes,
-
+causal=self.is_causal
 )
 attention = attention.to(orig_dtype) # type: ignore
 else:
@@ -156,10 +156,11 @@ class BertAlibiUnpadSelfAttention(nn.Module):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 alibi_slopes=slopes,
-
+causal = self.is_causal
 )
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
 unpad_bs, *_ = qkv.shape
 qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
@@ -236,6 +237,7 @@ class BertAlibiUnpadAttention(nn.Module):
 slopes: None or (batch, heads) or (heads,)
 """
 assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
+assert False
 self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
 if subset_idx is not None:
 return self.output(
@@ -293,7 +295,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -402,7 +404,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -413,11 +415,12 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 attn = attn.view(bs, dim)
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
 unpad_bs, seqlen, _ = qkv.shape
 
@@ -456,7 +459,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.hidden_size = config.hidden_size
@@ -556,7 +559,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -567,11 +570,12 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 attn = attn.view(bs, dim)
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
 unpad_bs, seqlen, _ = qkv.shape
 
@@ -610,7 +614,7 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -695,7 +699,7 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -704,10 +708,11 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
 
 q, k, v = qkv.transpose(3, 1).unbind(dim=2)
@@ -743,7 +748,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -882,7 +887,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 max_seqlen_q=max_seqlen,
 max_seqlen_k=max_seqlen,
 deterministic=self.deterministic_fa2,
-causal=self.
+causal=self.is_causal,
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -896,7 +901,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 max_seqlen_q=max_seqlen,
 max_seqlen_k=max_seqlen,
 deterministic=self.deterministic_fa2,
-causal=self.
+causal=self.is_causal,
 )
 attn = attn.view(bs, dim)
 elif self.use_fa2:
@@ -914,7 +919,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-causal=self.
+causal=self.is_causal,
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -925,11 +930,12 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-causal=self.
+causal=self.is_causal,
 )
 attn = attn.view(bs, dim)
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = bert_padding.pad_input(
 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
 ) # batch, max_seqlen, thd
@@ -969,7 +975,7 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -1080,7 +1086,7 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal,
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -1089,10 +1095,11 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
 q, k, v = qkv.transpose(3, 1).unbind(dim=2)
 attn = F.scaled_dot_product_attention(
@@ -1127,7 +1134,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.hidden_size = config.hidden_size
@@ -1253,7 +1260,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal,
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -1264,11 +1271,12 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal,
 )
 attn = attn.view(bs, dim)
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = bert_padding.pad_input(
 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
 ) # batch, max_seqlen, thd
@@ -1308,7 +1316,7 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.hidden_size = config.hidden_size
@@ -1413,7 +1421,7 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -1422,10 +1430,11 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 else:
-assert not self.
+assert not self.is_causal, f"causal mask not implemented here yet"
+assert False
 qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
 q, k, v = qkv.transpose(3, 1).unbind(dim=2)
 attn = F.scaled_dot_product_attention(
@@ -1460,7 +1469,7 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
 f"heads ({config.num_attention_heads})"
 )
 
-self.
+self.is_causal = config.causal_mask
 self.num_attention_heads = config.num_attention_heads
 self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
 self.hidden_size = config.hidden_size
@@ -1537,7 +1546,7 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 attn = attn.to(orig_dtype) # type: ignore
 else:
@@ -1546,10 +1555,11 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
 dropout_p=self.p_dropout,
 deterministic=self.deterministic_fa2,
 window_size=self.sliding_window,
-
+causal=self.is_causal
 )
 else:
-assert not self.
+assert not self.is_causal, f"causal attention mask not yet implemented here"
+assert False
 qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
 q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
 attn = F.scaled_dot_product_attention(
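Every attention variant above follows the same pattern: when the FlashAttention 2 kernels are in use, the new self.is_causal flag (read from config.causal_mask) is forwarded as the kernel's causal argument, while the eager/SDPA fallback branch now asserts that causal masking is not requested. A standalone sketch of the unpadded call shape, assuming the flash-attn 2.x packed-QKV varlen API; the function and argument names here are illustrative, not the module's own:

    from flash_attn import flash_attn_varlen_qkvpacked_func

    def unpadded_attention(qkv, cu_seqlens, max_seqlen, slopes=None, is_causal=False, p_dropout=0.0):
        # qkv: (total_tokens, 3, n_heads, head_dim), all sequences concatenated (unpadded)
        # cu_seqlens: (batch + 1,) int32 cumulative sequence lengths
        return flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p=p_dropout,
            alibi_slopes=slopes,   # None for the RoPE variants
            deterministic=False,
            causal=is_causal,      # the flag this commit threads through from config.causal_mask
        )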
config.json CHANGED

@@ -88,4 +88,4 @@
 "use_sdpa_attn_mask": false,
 "vocab_size": 50368,
 "is_casual": true
-}
+}
configuration_bert.py CHANGED

@@ -97,7 +97,7 @@ class FlexBertConfig(TransformersBertConfig):
 pad_logits: bool = False,
 compile_model: bool = False,
 masked_prediction: bool = False,
-
+causal_mask: bool = False,
 **kwargs,
 ):
 """
@@ -157,7 +157,7 @@ class FlexBertConfig(TransformersBertConfig):
 pad_logits (bool): Pad logits after the calculating the loss.
 compile_model (bool): Compile the subset of the model which can be compiled.
 masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers
-
+causal (bool): Use a causal mask, defaulting to false.
 **kwargs: Additional keyword arguments.
 """
 super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
@@ -215,7 +215,7 @@ class FlexBertConfig(TransformersBertConfig):
 self.pad_logits = pad_logits
 self.compile_model = compile_model
 self.masked_prediction = masked_prediction
-self.
+self.causal_mask = causal_mask
 
 if loss_kwargs.get("return_z_loss", False):
 if loss_function != "fa_cross_entropy":
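With the new field, switching the model into causal mode is a single config flag. A minimal sketch, assuming the file is importable as configuration_bert (per the relative imports in __init__.py) and that the remaining constructor arguments keep their defaults:

    from configuration_bert import FlexBertConfig

    # causal_mask defaults to False; the attention modules read it as self.is_causal,
    # and the SDPA fallback paths refuse to run when it is set.
    config = FlexBertConfig(causal_mask=True)
    assert config.causal_mask is True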
modeling_flexbert.py CHANGED

@@ -125,7 +125,6 @@ from .rotary import UnpaddedRotaryEmbedding
 
 logger = logging.getLogger(__name__)
 
-
 def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
 if trainable:
 return sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -873,7 +872,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
 
 def _init_module_weights(self, module: nn.Module):
 """
-Custom weight init of modules using .bert_layers.initialization.init_weights
+Custom weight init of modules using src.bert_layers.initialization.init_weights
 Currently only supports init of embedding modules
 """
 assert isinstance(module, nn.Module)
@@ -1126,7 +1125,6 @@ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
 # seqlen) dimensions are flattened
 
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
 if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
 batch_size, seq_len = input_ids.shape[:2]
 input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
@@ -1506,9 +1504,7 @@ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
 return params
 
 
-class FlexBertForCasualLM(FlexBertPreTrainedModel):
-config_class = FlexBertConfig
-
+class FlexBertForCausalLM(FlexBertPreTrainedModel):
 """Bert Model transformer with a LM head.
 
 This head is just a standard LM head module. Used for causal language modeling tasks.
@@ -1538,23 +1534,14 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 self._init_weights(reset_params=False)
 
 def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
-# Handle the XOR condition
 assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
-
-
-# Add basic initialization for common module types
-if isinstance(module, (nn.Linear, nn.Embedding)):
-module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-if isinstance(module, nn.Linear) and module.bias is not None:
-module.bias.data.zero_()
-elif isinstance(module, nn.LayerNorm):
-module.bias.data.zero_()
-module.weight.data.fill_(1.0)
+if module:
+self._init_module_weights(module)
 else:
 assert isinstance(reset_params, bool)
 self.bert._init_weights(reset_params=reset_params)
 self.lm_head._init_weights(reset_params=reset_params)
-
+
 if not self.config.tie_word_embeddings:
 init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
 
@@ -1644,7 +1631,6 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 # seqlen) dimensions are flattened
 
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
 if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
 batch_size, seq_len = input_ids.shape[:2]
 input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
@@ -1664,29 +1650,28 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 logits = self.compiled_lm_head(hidden_states)
 else:
 logits = self.lm_head(hidden_states)
-
+
 loss = None
 if labels is not None:
-if
-# Unpadded case: shift within each sequence using input_ids
-# Initialize shifted labels from input_ids
+if cu_seqlens is not None:
 shift_labels = torch.full_like(input_ids, -100)
-
-
+shift_labels[:-1] = input_ids[1:]
+
+# Mask boundaries
 for i in range(len(cu_seqlens) - 1):
-
-
-
-
-
-
-
-
-
-
-
-
+boundary_pos = cu_seqlens[i+1] - 1
+shift_labels[boundary_pos] = -100
+
+# Mask out PAD tokens
+mask = (shift_labels == 50283)
+shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
+
+
+# print input_ids[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
+# print shift_labels[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
+# print input_ids[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
+# print shift_labels[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
+# breakpoint() # pkill -u oweller2 -f wandb
 
 else:
 # Padded case: simple shift
@@ -1703,7 +1688,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 )
 
 if self.pad_logits:
-print(f"Padding logits: {logits.shape}")
+# print(f"Padding logits: {logits.shape}")
 new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
 if len(new_logits.shape) == 2:
 new_logits = new_logits.unsqueeze(0)
@@ -1714,7 +1699,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 attentions=None,
 )
 else:
-print(f"Non-padding logits: {logits.shape}")
+# print(f"Non-padding logits: {logits.shape}")
 if len(logits.shape) == 2:
 logits = logits.unsqueeze(0)
 return CausalLMOutput(
@@ -1757,7 +1742,6 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 params += _count_parameters(self.lm_head, trainable)
 return params
 
-FlexBertForCasualLM.register_for_auto_class("AutoModelForCausalLM")
 
 def init_model_from_pretrained(
 pretrained_model: FlexBertModel,
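The new unpadded loss path in FlexBertForCausalLM shifts next-token labels inside one flattened token stream, so the last position of each packed sequence (found via cu_seqlens) and PAD tokens (id 50283 in the diff) have to be masked by hand. A standalone sketch of that logic, mirroring the lines added above:

    import torch

    def shift_labels_unpadded(input_ids: torch.Tensor, cu_seqlens: torch.Tensor,
                              pad_token_id: int = 50283) -> torch.Tensor:
        """Next-token labels for a flattened (unpadded) batch of concatenated sequences."""
        shift_labels = torch.full_like(input_ids, -100)
        shift_labels[:-1] = input_ids[1:]          # predict token t+1 at position t
        for i in range(len(cu_seqlens) - 1):
            boundary_pos = cu_seqlens[i + 1] - 1   # last token of sequence i
            shift_labels[boundary_pos] = -100      # never predict across a sequence boundary
        # mask out positions whose target is the PAD token
        return torch.where(shift_labels == pad_token_id,
                           torch.full_like(shift_labels, -100), shift_labels)

    # Two sequences of lengths 3 and 2, concatenated:
    ids = torch.tensor([5, 6, 7, 8, 9])
    cu = torch.tensor([0, 3, 5])
    print(shift_labels_unpadded(ids, cu))  # tensor([   6,    7, -100,    9, -100])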