oweller2 commited on
Commit
e0229bb
1 Parent(s): 1f59624

added updated code:

Browse files
Files changed (5) hide show
  1. __init__.py +3 -3
  2. attention.py +48 -38
  3. config.json +1 -1
  4. configuration_bert.py +3 -3
  5. modeling_flexbert.py +25 -41
__init__.py CHANGED
@@ -33,13 +33,14 @@ from .modeling_flexbert import (
33
  FlexBertForMaskedLM,
34
  FlexBertForSequenceClassification,
35
  FlexBertForMultipleChoice,
36
- FlexBertForCasualLM,
37
  )
38
  from .bert_padding import(
39
  IndexFirstAxis,
40
  IndexPutFirstAxis
41
  )
42
 
 
43
  __all__ = [
44
  "BertAlibiEmbeddings",
45
  "BertAlibiEncoder",
@@ -69,6 +70,5 @@ __all__ = [
69
  "FlexBertForMaskedLM",
70
  "FlexBertForSequenceClassification",
71
  "FlexBertForMultipleChoice",
72
- "IndexFirstAxis",
73
- "IndexPutFirstAxis"
74
  ]
 
33
  FlexBertForMaskedLM,
34
  FlexBertForSequenceClassification,
35
  FlexBertForMultipleChoice,
36
+ FlexBertForCausalLM,
37
  )
38
  from .bert_padding import(
39
  IndexFirstAxis,
40
  IndexPutFirstAxis
41
  )
42
 
43
+
44
  __all__ = [
45
  "BertAlibiEmbeddings",
46
  "BertAlibiEncoder",
 
70
  "FlexBertForMaskedLM",
71
  "FlexBertForSequenceClassification",
72
  "FlexBertForMultipleChoice",
73
+ "FlexBertForCausalLM"
 
74
  ]
attention.py CHANGED
@@ -74,7 +74,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
74
  f"heads ({config.num_attention_heads})"
75
  )
76
 
77
- self.is_casual = config.casual_mask
78
  self.num_attention_heads = config.num_attention_heads
79
  self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
80
  self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -145,7 +145,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
145
  dropout_p=self.p_dropout,
146
  deterministic=self.deterministic_fa2,
147
  alibi_slopes=slopes,
148
- casual=self.is_casual
149
  )
150
  attention = attention.to(orig_dtype) # type: ignore
151
  else:
@@ -156,10 +156,11 @@ class BertAlibiUnpadSelfAttention(nn.Module):
156
  dropout_p=self.p_dropout,
157
  deterministic=self.deterministic_fa2,
158
  alibi_slopes=slopes,
159
- casual = self.is_casual
160
  )
161
  else:
162
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
163
  qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
164
  unpad_bs, *_ = qkv.shape
165
  qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
@@ -236,6 +237,7 @@ class BertAlibiUnpadAttention(nn.Module):
236
  slopes: None or (batch, heads) or (heads,)
237
  """
238
  assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
 
239
  self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
240
  if subset_idx is not None:
241
  return self.output(
@@ -293,7 +295,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
293
  f"heads ({config.num_attention_heads})"
294
  )
295
 
296
- self.is_casual = config.casual_mask
297
  self.num_attention_heads = config.num_attention_heads
298
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
299
  self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -402,7 +404,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
402
  dropout_p=self.p_dropout,
403
  deterministic=self.deterministic_fa2,
404
  window_size=self.sliding_window,
405
- casual=self.is_casual
406
  )
407
  attn = attn.to(orig_dtype) # type: ignore
408
  else:
@@ -413,11 +415,12 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
413
  dropout_p=self.p_dropout,
414
  deterministic=self.deterministic_fa2,
415
  window_size=self.sliding_window,
416
- casual=self.is_casual
417
  )
418
  attn = attn.view(bs, dim)
419
  else:
420
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
421
  qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
422
  unpad_bs, seqlen, _ = qkv.shape
423
 
@@ -456,7 +459,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
456
  f"heads ({config.num_attention_heads})"
457
  )
458
 
459
- self.is_casual = config.casual_mask
460
  self.num_attention_heads = config.num_attention_heads
461
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
462
  self.hidden_size = config.hidden_size
@@ -556,7 +559,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
556
  dropout_p=self.p_dropout,
557
  deterministic=self.deterministic_fa2,
558
  window_size=self.sliding_window,
559
- casual=self.is_casual
560
  )
561
  attn = attn.to(orig_dtype) # type: ignore
562
  else:
@@ -567,11 +570,12 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
567
  dropout_p=self.p_dropout,
568
  deterministic=self.deterministic_fa2,
569
  window_size=self.sliding_window,
570
- casual=self.is_casual
571
  )
572
  attn = attn.view(bs, dim)
573
  else:
574
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
575
  qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
576
  unpad_bs, seqlen, _ = qkv.shape
577
 
@@ -610,7 +614,7 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
610
  f"heads ({config.num_attention_heads})"
611
  )
612
 
613
- self.is_casual = config.casual_mask
614
  self.num_attention_heads = config.num_attention_heads
615
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
616
  self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -695,7 +699,7 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
695
  dropout_p=self.p_dropout,
696
  deterministic=self.deterministic_fa2,
697
  window_size=self.sliding_window,
698
- casual=self.is_casual
699
  )
700
  attn = attn.to(orig_dtype) # type: ignore
701
  else:
@@ -704,10 +708,11 @@ class FlexBertPaddedAttention(FlexBertAttentionBase):
704
  dropout_p=self.p_dropout,
705
  deterministic=self.deterministic_fa2,
706
  window_size=self.sliding_window,
707
- casual=self.is_casual
708
  )
709
  else:
710
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
711
  qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
712
 
713
  q, k, v = qkv.transpose(3, 1).unbind(dim=2)
@@ -743,7 +748,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
743
  f"heads ({config.num_attention_heads})"
744
  )
745
 
746
- self.is_casual = config.casual_mask
747
  self.num_attention_heads = config.num_attention_heads
748
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
749
  self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -882,7 +887,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
882
  max_seqlen_q=max_seqlen,
883
  max_seqlen_k=max_seqlen,
884
  deterministic=self.deterministic_fa2,
885
- causal=self.is_casual,
886
  )
887
  attn = attn.to(orig_dtype) # type: ignore
888
  else:
@@ -896,7 +901,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
896
  max_seqlen_q=max_seqlen,
897
  max_seqlen_k=max_seqlen,
898
  deterministic=self.deterministic_fa2,
899
- causal=self.is_casual,
900
  )
901
  attn = attn.view(bs, dim)
902
  elif self.use_fa2:
@@ -914,7 +919,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
914
  dropout_p=self.p_dropout,
915
  deterministic=self.deterministic_fa2,
916
  window_size=self.sliding_window,
917
- causal=self.is_casual,
918
  )
919
  attn = attn.to(orig_dtype) # type: ignore
920
  else:
@@ -925,11 +930,12 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
925
  dropout_p=self.p_dropout,
926
  deterministic=self.deterministic_fa2,
927
  window_size=self.sliding_window,
928
- causal=self.is_casual,
929
  )
930
  attn = attn.view(bs, dim)
931
  else:
932
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
933
  qkv = bert_padding.pad_input(
934
  qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
935
  ) # batch, max_seqlen, thd
@@ -969,7 +975,7 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
969
  f"heads ({config.num_attention_heads})"
970
  )
971
 
972
- self.is_casual = config.casual_mask
973
  self.num_attention_heads = config.num_attention_heads
974
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
975
  self.all_head_size = self.num_attention_heads * self.attn_head_size
@@ -1080,7 +1086,7 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
1080
  dropout_p=self.p_dropout,
1081
  deterministic=self.deterministic_fa2,
1082
  window_size=self.sliding_window,
1083
- casual=self.is_casual,
1084
  )
1085
  attn = attn.to(orig_dtype) # type: ignore
1086
  else:
@@ -1089,10 +1095,11 @@ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
1089
  dropout_p=self.p_dropout,
1090
  deterministic=self.deterministic_fa2,
1091
  window_size=self.sliding_window,
1092
- casual=self.is_casual
1093
  )
1094
  else:
1095
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
1096
  qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1097
  q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1098
  attn = F.scaled_dot_product_attention(
@@ -1127,7 +1134,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
1127
  f"heads ({config.num_attention_heads})"
1128
  )
1129
 
1130
- self.is_casual = config.casual_mask
1131
  self.num_attention_heads = config.num_attention_heads
1132
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1133
  self.hidden_size = config.hidden_size
@@ -1253,7 +1260,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
1253
  dropout_p=self.p_dropout,
1254
  deterministic=self.deterministic_fa2,
1255
  window_size=self.sliding_window,
1256
- casual=self.is_casual,
1257
  )
1258
  attn = attn.to(orig_dtype) # type: ignore
1259
  else:
@@ -1264,11 +1271,12 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
1264
  dropout_p=self.p_dropout,
1265
  deterministic=self.deterministic_fa2,
1266
  window_size=self.sliding_window,
1267
- casual=self.is_casual,
1268
  )
1269
  attn = attn.view(bs, dim)
1270
  else:
1271
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
1272
  qkv = bert_padding.pad_input(
1273
  qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
1274
  ) # batch, max_seqlen, thd
@@ -1308,7 +1316,7 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
1308
  f"heads ({config.num_attention_heads})"
1309
  )
1310
 
1311
- self.is_casual = config.casual_mask
1312
  self.num_attention_heads = config.num_attention_heads
1313
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1314
  self.hidden_size = config.hidden_size
@@ -1413,7 +1421,7 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
1413
  dropout_p=self.p_dropout,
1414
  deterministic=self.deterministic_fa2,
1415
  window_size=self.sliding_window,
1416
- casual=self.is_casual
1417
  )
1418
  attn = attn.to(orig_dtype) # type: ignore
1419
  else:
@@ -1422,10 +1430,11 @@ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
1422
  dropout_p=self.p_dropout,
1423
  deterministic=self.deterministic_fa2,
1424
  window_size=self.sliding_window,
1425
- casual=self.is_casual
1426
  )
1427
  else:
1428
- assert not self.is_casual, f"Casual mask not implemented here yet"
 
1429
  qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1430
  q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1431
  attn = F.scaled_dot_product_attention(
@@ -1460,7 +1469,7 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
1460
  f"heads ({config.num_attention_heads})"
1461
  )
1462
 
1463
- self.is_casual = config.casual_mask
1464
  self.num_attention_heads = config.num_attention_heads
1465
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1466
  self.hidden_size = config.hidden_size
@@ -1537,7 +1546,7 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
1537
  dropout_p=self.p_dropout,
1538
  deterministic=self.deterministic_fa2,
1539
  window_size=self.sliding_window,
1540
- casual=self.is_casual
1541
  )
1542
  attn = attn.to(orig_dtype) # type: ignore
1543
  else:
@@ -1546,10 +1555,11 @@ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
1546
  dropout_p=self.p_dropout,
1547
  deterministic=self.deterministic_fa2,
1548
  window_size=self.sliding_window,
1549
- casual=self.is_casual
1550
  )
1551
  else:
1552
- assert not self.is_casual, f"Casual attention mask not yet implemented here"
 
1553
  qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1554
  q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
1555
  attn = F.scaled_dot_product_attention(
 
74
  f"heads ({config.num_attention_heads})"
75
  )
76
 
77
+ self.is_causal = config.causal_mask
78
  self.num_attention_heads = config.num_attention_heads
79
  self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
80
  self.all_head_size = self.num_attention_heads * self.attention_head_size
 
145
  dropout_p=self.p_dropout,
146
  deterministic=self.deterministic_fa2,
147
  alibi_slopes=slopes,
148
+ causal=self.is_causal
149
  )
150
  attention = attention.to(orig_dtype) # type: ignore
151
  else:
 
156
  dropout_p=self.p_dropout,
157
  deterministic=self.deterministic_fa2,
158
  alibi_slopes=slopes,
159
+ causal = self.is_causal
160
  )
161
  else:
162
+ assert not self.is_causal, f"causal mask not implemented here yet"
163
+ assert False
164
  qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
165
  unpad_bs, *_ = qkv.shape
166
  qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
 
237
  slopes: None or (batch, heads) or (heads,)
238
  """
239
  assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
240
+ assert False
241
  self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
242
  if subset_idx is not None:
243
  return self.output(
 
295
  f"heads ({config.num_attention_heads})"
296
  )
297
 
298
+ self.is_causal = config.causal_mask
299
  self.num_attention_heads = config.num_attention_heads
300
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
301
  self.all_head_size = self.num_attention_heads * self.attn_head_size
 
404
  dropout_p=self.p_dropout,
405
  deterministic=self.deterministic_fa2,
406
  window_size=self.sliding_window,
407
+ causal=self.is_causal
408
  )
409
  attn = attn.to(orig_dtype) # type: ignore
410
  else:
 
415
  dropout_p=self.p_dropout,
416
  deterministic=self.deterministic_fa2,
417
  window_size=self.sliding_window,
418
+ causal=self.is_causal
419
  )
420
  attn = attn.view(bs, dim)
421
  else:
422
+ assert not self.is_causal, f"causal mask not implemented here yet"
423
+ assert False
424
  qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
425
  unpad_bs, seqlen, _ = qkv.shape
426
 
 
459
  f"heads ({config.num_attention_heads})"
460
  )
461
 
462
+ self.is_causal = config.causal_mask
463
  self.num_attention_heads = config.num_attention_heads
464
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
465
  self.hidden_size = config.hidden_size
 
559
  dropout_p=self.p_dropout,
560
  deterministic=self.deterministic_fa2,
561
  window_size=self.sliding_window,
562
+ causal=self.is_causal
563
  )
564
  attn = attn.to(orig_dtype) # type: ignore
565
  else:
 
570
  dropout_p=self.p_dropout,
571
  deterministic=self.deterministic_fa2,
572
  window_size=self.sliding_window,
573
+ causal=self.is_causal
574
  )
575
  attn = attn.view(bs, dim)
576
  else:
577
+ assert not self.is_causal, f"causal mask not implemented here yet"
578
+ assert False
579
  qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
580
  unpad_bs, seqlen, _ = qkv.shape
581
 
 
614
  f"heads ({config.num_attention_heads})"
615
  )
616
 
617
+ self.is_causal = config.causal_mask
618
  self.num_attention_heads = config.num_attention_heads
619
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
620
  self.all_head_size = self.num_attention_heads * self.attn_head_size
 
699
  dropout_p=self.p_dropout,
700
  deterministic=self.deterministic_fa2,
701
  window_size=self.sliding_window,
702
+ causal=self.is_causal
703
  )
704
  attn = attn.to(orig_dtype) # type: ignore
705
  else:
 
708
  dropout_p=self.p_dropout,
709
  deterministic=self.deterministic_fa2,
710
  window_size=self.sliding_window,
711
+ causal=self.is_causal
712
  )
713
  else:
714
+ assert not self.is_causal, f"causal mask not implemented here yet"
715
+ assert False
716
  qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
717
 
718
  q, k, v = qkv.transpose(3, 1).unbind(dim=2)
 
748
  f"heads ({config.num_attention_heads})"
749
  )
750
 
751
+ self.is_causal = config.causal_mask
752
  self.num_attention_heads = config.num_attention_heads
753
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
754
  self.all_head_size = self.num_attention_heads * self.attn_head_size
 
887
  max_seqlen_q=max_seqlen,
888
  max_seqlen_k=max_seqlen,
889
  deterministic=self.deterministic_fa2,
890
+ causal=self.is_causal,
891
  )
892
  attn = attn.to(orig_dtype) # type: ignore
893
  else:
 
901
  max_seqlen_q=max_seqlen,
902
  max_seqlen_k=max_seqlen,
903
  deterministic=self.deterministic_fa2,
904
+ causal=self.is_causal,
905
  )
906
  attn = attn.view(bs, dim)
907
  elif self.use_fa2:
 
919
  dropout_p=self.p_dropout,
920
  deterministic=self.deterministic_fa2,
921
  window_size=self.sliding_window,
922
+ causal=self.is_causal,
923
  )
924
  attn = attn.to(orig_dtype) # type: ignore
925
  else:
 
930
  dropout_p=self.p_dropout,
931
  deterministic=self.deterministic_fa2,
932
  window_size=self.sliding_window,
933
+ causal=self.is_causal,
934
  )
935
  attn = attn.view(bs, dim)
936
  else:
937
+ assert not self.is_causal, f"causal mask not implemented here yet"
938
+ assert False
939
  qkv = bert_padding.pad_input(
940
  qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
941
  ) # batch, max_seqlen, thd
 
975
  f"heads ({config.num_attention_heads})"
976
  )
977
 
978
+ self.is_causal = config.causal_mask
979
  self.num_attention_heads = config.num_attention_heads
980
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
981
  self.all_head_size = self.num_attention_heads * self.attn_head_size
 
1086
  dropout_p=self.p_dropout,
1087
  deterministic=self.deterministic_fa2,
1088
  window_size=self.sliding_window,
1089
+ causal=self.is_causal,
1090
  )
1091
  attn = attn.to(orig_dtype) # type: ignore
1092
  else:
 
1095
  dropout_p=self.p_dropout,
1096
  deterministic=self.deterministic_fa2,
1097
  window_size=self.sliding_window,
1098
+ causal=self.is_causal
1099
  )
1100
  else:
1101
+ assert not self.is_causal, f"causal mask not implemented here yet"
1102
+ assert False
1103
  qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1104
  q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1105
  attn = F.scaled_dot_product_attention(
 
1134
  f"heads ({config.num_attention_heads})"
1135
  )
1136
 
1137
+ self.is_causal = config.causal_mask
1138
  self.num_attention_heads = config.num_attention_heads
1139
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1140
  self.hidden_size = config.hidden_size
 
1260
  dropout_p=self.p_dropout,
1261
  deterministic=self.deterministic_fa2,
1262
  window_size=self.sliding_window,
1263
+ causal=self.is_causal,
1264
  )
1265
  attn = attn.to(orig_dtype) # type: ignore
1266
  else:
 
1271
  dropout_p=self.p_dropout,
1272
  deterministic=self.deterministic_fa2,
1273
  window_size=self.sliding_window,
1274
+ causal=self.is_causal,
1275
  )
1276
  attn = attn.view(bs, dim)
1277
  else:
1278
+ assert not self.is_causal, f"causal mask not implemented here yet"
1279
+ assert False
1280
  qkv = bert_padding.pad_input(
1281
  qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
1282
  ) # batch, max_seqlen, thd
 
1316
  f"heads ({config.num_attention_heads})"
1317
  )
1318
 
1319
+ self.is_causal = config.causal_mask
1320
  self.num_attention_heads = config.num_attention_heads
1321
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1322
  self.hidden_size = config.hidden_size
 
1421
  dropout_p=self.p_dropout,
1422
  deterministic=self.deterministic_fa2,
1423
  window_size=self.sliding_window,
1424
+ causal=self.is_causal
1425
  )
1426
  attn = attn.to(orig_dtype) # type: ignore
1427
  else:
 
1430
  dropout_p=self.p_dropout,
1431
  deterministic=self.deterministic_fa2,
1432
  window_size=self.sliding_window,
1433
+ causal=self.is_causal
1434
  )
1435
  else:
1436
+ assert not self.is_causal, f"causal mask not implemented here yet"
1437
+ assert False
1438
  qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1439
  q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1440
  attn = F.scaled_dot_product_attention(
 
1469
  f"heads ({config.num_attention_heads})"
1470
  )
1471
 
1472
+ self.is_causal = config.causal_mask
1473
  self.num_attention_heads = config.num_attention_heads
1474
  self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1475
  self.hidden_size = config.hidden_size
 
1546
  dropout_p=self.p_dropout,
1547
  deterministic=self.deterministic_fa2,
1548
  window_size=self.sliding_window,
1549
+ causal=self.is_causal
1550
  )
1551
  attn = attn.to(orig_dtype) # type: ignore
1552
  else:
 
1555
  dropout_p=self.p_dropout,
1556
  deterministic=self.deterministic_fa2,
1557
  window_size=self.sliding_window,
1558
+ causal=self.is_causal
1559
  )
1560
  else:
1561
+ assert not self.is_causal, f"causal attention mask not yet implemented here"
1562
+ assert False
1563
  qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1564
  q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
1565
  attn = F.scaled_dot_product_attention(
config.json CHANGED
@@ -88,4 +88,4 @@
88
  "use_sdpa_attn_mask": false,
89
  "vocab_size": 50368,
90
  "is_casual": true
91
- }
 
88
  "use_sdpa_attn_mask": false,
89
  "vocab_size": 50368,
90
  "is_casual": true
91
+ }
configuration_bert.py CHANGED
@@ -97,7 +97,7 @@ class FlexBertConfig(TransformersBertConfig):
97
  pad_logits: bool = False,
98
  compile_model: bool = False,
99
  masked_prediction: bool = False,
100
- casual_mask: bool = False,
101
  **kwargs,
102
  ):
103
  """
@@ -157,7 +157,7 @@ class FlexBertConfig(TransformersBertConfig):
157
  pad_logits (bool): Pad logits after the calculating the loss.
158
  compile_model (bool): Compile the subset of the model which can be compiled.
159
  masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers
160
- casual_mask (bool): Use a casual mask, defaulting to false.
161
  **kwargs: Additional keyword arguments.
162
  """
163
  super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
@@ -215,7 +215,7 @@ class FlexBertConfig(TransformersBertConfig):
215
  self.pad_logits = pad_logits
216
  self.compile_model = compile_model
217
  self.masked_prediction = masked_prediction
218
- self.casual_mask = casual_mask
219
 
220
  if loss_kwargs.get("return_z_loss", False):
221
  if loss_function != "fa_cross_entropy":
 
97
  pad_logits: bool = False,
98
  compile_model: bool = False,
99
  masked_prediction: bool = False,
100
+ causal_mask: bool = False,
101
  **kwargs,
102
  ):
103
  """
 
157
  pad_logits (bool): Pad logits after the calculating the loss.
158
  compile_model (bool): Compile the subset of the model which can be compiled.
159
  masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers
160
+ causal (bool): Use a causal mask, defaulting to false.
161
  **kwargs: Additional keyword arguments.
162
  """
163
  super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
 
215
  self.pad_logits = pad_logits
216
  self.compile_model = compile_model
217
  self.masked_prediction = masked_prediction
218
+ self.causal_mask = causal_mask
219
 
220
  if loss_kwargs.get("return_z_loss", False):
221
  if loss_function != "fa_cross_entropy":
modeling_flexbert.py CHANGED
@@ -125,7 +125,6 @@ from .rotary import UnpaddedRotaryEmbedding
125
 
126
  logger = logging.getLogger(__name__)
127
 
128
-
129
  def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
130
  if trainable:
131
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -873,7 +872,7 @@ class FlexBertPreTrainedModel(BertPreTrainedModel):
873
 
874
  def _init_module_weights(self, module: nn.Module):
875
  """
876
- Custom weight init of modules using .bert_layers.initialization.init_weights
877
  Currently only supports init of embedding modules
878
  """
879
  assert isinstance(module, nn.Module)
@@ -1126,7 +1125,6 @@ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
1126
  # seqlen) dimensions are flattened
1127
 
1128
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1129
-
1130
  if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1131
  batch_size, seq_len = input_ids.shape[:2]
1132
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
@@ -1506,9 +1504,7 @@ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
1506
  return params
1507
 
1508
 
1509
- class FlexBertForCasualLM(FlexBertPreTrainedModel):
1510
- config_class = FlexBertConfig
1511
-
1512
  """Bert Model transformer with a LM head.
1513
 
1514
  This head is just a standard LM head module. Used for causal language modeling tasks.
@@ -1538,23 +1534,14 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1538
  self._init_weights(reset_params=False)
1539
 
1540
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1541
- # Handle the XOR condition
1542
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1543
-
1544
- if module is not None:
1545
- # Add basic initialization for common module types
1546
- if isinstance(module, (nn.Linear, nn.Embedding)):
1547
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1548
- if isinstance(module, nn.Linear) and module.bias is not None:
1549
- module.bias.data.zero_()
1550
- elif isinstance(module, nn.LayerNorm):
1551
- module.bias.data.zero_()
1552
- module.weight.data.fill_(1.0)
1553
  else:
1554
  assert isinstance(reset_params, bool)
1555
  self.bert._init_weights(reset_params=reset_params)
1556
  self.lm_head._init_weights(reset_params=reset_params)
1557
-
1558
  if not self.config.tie_word_embeddings:
1559
  init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1560
 
@@ -1644,7 +1631,6 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1644
  # seqlen) dimensions are flattened
1645
 
1646
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1647
-
1648
  if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1649
  batch_size, seq_len = input_ids.shape[:2]
1650
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
@@ -1664,29 +1650,28 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1664
  logits = self.compiled_lm_head(hidden_states)
1665
  else:
1666
  logits = self.lm_head(hidden_states)
1667
-
1668
  loss = None
1669
  if labels is not None:
1670
- if indices is not None:
1671
- # Unpadded case: shift within each sequence using input_ids
1672
- # Initialize shifted labels from input_ids
1673
  shift_labels = torch.full_like(input_ids, -100)
1674
-
1675
- # For each sequence, shift the input_ids to create labels
 
1676
  for i in range(len(cu_seqlens) - 1):
1677
- start = cu_seqlens[i]
1678
- end = cu_seqlens[i + 1]
1679
- # Input: [A, B, C, D] -> Labels: [B, C, D, -100]
1680
- shift_labels[start:end-1] = input_ids[start+1:end]
1681
-
1682
- # Debug prints
1683
- # print(f"input_ids slice: {input_ids[:20]}") # Show first 20 tokens
1684
- # print(f"shift_labels slice: {shift_labels[:20]}") # Show first 20 token
1685
-
1686
- # # Debug prints
1687
- # print(f"input_ids slice: {input_ids[:20]}") # Show first 20 tokens
1688
- # print(f"shift_labels slice: {shift_labels[:20]}") # Show first 20 tokens
1689
- # print(f"First sequence length: {cu_seqlens[1] - cu_seqlens[0]}")
1690
 
1691
  else:
1692
  # Padded case: simple shift
@@ -1703,7 +1688,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1703
  )
1704
 
1705
  if self.pad_logits:
1706
- print(f"Padding logits: {logits.shape}")
1707
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
1708
  if len(new_logits.shape) == 2:
1709
  new_logits = new_logits.unsqueeze(0)
@@ -1714,7 +1699,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1714
  attentions=None,
1715
  )
1716
  else:
1717
- print(f"Non-padding logits: {logits.shape}")
1718
  if len(logits.shape) == 2:
1719
  logits = logits.unsqueeze(0)
1720
  return CausalLMOutput(
@@ -1757,7 +1742,6 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1757
  params += _count_parameters(self.lm_head, trainable)
1758
  return params
1759
 
1760
- FlexBertForCasualLM.register_for_auto_class("AutoModelForCausalLM")
1761
 
1762
  def init_model_from_pretrained(
1763
  pretrained_model: FlexBertModel,
 
125
 
126
  logger = logging.getLogger(__name__)
127
 
 
128
  def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
129
  if trainable:
130
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
872
 
873
  def _init_module_weights(self, module: nn.Module):
874
  """
875
+ Custom weight init of modules using src.bert_layers.initialization.init_weights
876
  Currently only supports init of embedding modules
877
  """
878
  assert isinstance(module, nn.Module)
 
1125
  # seqlen) dimensions are flattened
1126
 
1127
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
1128
  if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1129
  batch_size, seq_len = input_ids.shape[:2]
1130
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
 
1504
  return params
1505
 
1506
 
1507
+ class FlexBertForCausalLM(FlexBertPreTrainedModel):
 
 
1508
  """Bert Model transformer with a LM head.
1509
 
1510
  This head is just a standard LM head module. Used for causal language modeling tasks.
 
1534
  self._init_weights(reset_params=False)
1535
 
1536
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
 
1537
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1538
+ if module:
1539
+ self._init_module_weights(module)
 
 
 
 
 
 
 
 
1540
  else:
1541
  assert isinstance(reset_params, bool)
1542
  self.bert._init_weights(reset_params=reset_params)
1543
  self.lm_head._init_weights(reset_params=reset_params)
1544
+
1545
  if not self.config.tie_word_embeddings:
1546
  init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1547
 
 
1631
  # seqlen) dimensions are flattened
1632
 
1633
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
1634
  if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1635
  batch_size, seq_len = input_ids.shape[:2]
1636
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
 
1650
  logits = self.compiled_lm_head(hidden_states)
1651
  else:
1652
  logits = self.lm_head(hidden_states)
1653
+
1654
  loss = None
1655
  if labels is not None:
1656
+ if cu_seqlens is not None:
 
 
1657
  shift_labels = torch.full_like(input_ids, -100)
1658
+ shift_labels[:-1] = input_ids[1:]
1659
+
1660
+ # Mask boundaries
1661
  for i in range(len(cu_seqlens) - 1):
1662
+ boundary_pos = cu_seqlens[i+1] - 1
1663
+ shift_labels[boundary_pos] = -100
1664
+
1665
+ # Mask out PAD tokens
1666
+ mask = (shift_labels == 50283)
1667
+ shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
1668
+
1669
+
1670
+ # print input_ids[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
1671
+ # print shift_labels[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
1672
+ # print input_ids[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
1673
+ # print shift_labels[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
1674
+ # breakpoint() # pkill -u oweller2 -f wandb
1675
 
1676
  else:
1677
  # Padded case: simple shift
 
1688
  )
1689
 
1690
  if self.pad_logits:
1691
+ # print(f"Padding logits: {logits.shape}")
1692
  new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
1693
  if len(new_logits.shape) == 2:
1694
  new_logits = new_logits.unsqueeze(0)
 
1699
  attentions=None,
1700
  )
1701
  else:
1702
+ # print(f"Non-padding logits: {logits.shape}")
1703
  if len(logits.shape) == 2:
1704
  logits = logits.unsqueeze(0)
1705
  return CausalLMOutput(
 
1742
  params += _count_parameters(self.lm_head, trainable)
1743
  return params
1744
 
 
1745
 
1746
  def init_model_from_pretrained(
1747
  pretrained_model: FlexBertModel,