oweller2 committed on
Commit 9f10682
1 parent: 4904e15

add unpad back in with created attn_mask

Files changed (2)
  1. config.json +1 -1
  2. modeling_flexbert.py +2 -0
config.json CHANGED
@@ -82,7 +82,7 @@
   "sliding_window": 128,
   "transformers_version": "4.44.1",
   "type_vocab_size": 2,
-  "unpad_embeddings": false,
+  "unpad_embeddings": true,
   "use_cache": true,
   "use_fa2": true,
   "use_sdpa_attn_mask": false,
modeling_flexbert.py CHANGED
@@ -1643,6 +1643,8 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
             batch_size, seq_len = input_ids.shape[:2]
+            if attention_mask is None:  # Create causal mask (lower triangular)
+                attention_mask = torch.tril(torch.ones(batch, seqlen), diagonal=0)
             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
                 input_ids, attention_mask, position_ids, labels
             )
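
The new branch only fires when the caller supplies no attention_mask and no precomputed unpadding metadata: it builds a mask and hands everything to unpad_inputs. Note that the added line reads torch.ones(batch, seqlen) while the surrounding code defines batch_size, seq_len, so the names may not match as committed. Below is a self-contained sketch, not the repository's unpad_inputs, of how a padding mask typically drives unpadding: kept tokens are gathered into one packed sequence, and cu_seqlens/max_seqlen are returned for FlashAttention-style kernels. The helper name unpad_with_mask and the pad id 0 are illustrative assumptions.

import torch
import torch.nn.functional as F

def unpad_with_mask(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # attention_mask: (batch_size, seq_len), 1 for real tokens, 0 for padding.
    mask = attention_mask.bool()
    seqlens = mask.sum(dim=1, dtype=torch.int32)                       # tokens kept per row
    indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()  # positions of kept tokens
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))   # [0, len_0, len_0+len_1, ...]
    max_seqlen = int(seqlens.max())
    unpadded_ids = input_ids.flatten()[indices]                        # (total_tokens,)
    return unpadded_ids, indices, cu_seqlens, max_seqlen

ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
mask = (ids != 0).int()                      # assume pad id 0 for this toy example
packed, indices, cu_seqlens, max_seqlen = unpad_with_mask(ids, mask)
print(packed.tolist())      # [5, 6, 7, 8, 9]
print(cu_seqlens.tolist())  # [0, 3, 5]
print(max_seqlen)           # 3
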