oweller2 committed on
Commit 81b671b (1 parent: b9219f0)

no unpad at inference

Files changed (2)
  1. config.json +4 -4
  2. modeling_flexbert.py +6 -26
config.json CHANGED
@@ -69,9 +69,9 @@
   "num_attention_heads": 12,
   "num_hidden_layers": 22,
   "num_initial_layers": 1,
-  "pad_logits": true,
-  "pad_token_id": 0,
-  "padding": "unpadded",
+  "pad_logits": false,
+  "pad_token_id": 50283,
+  "padding": "padded",
   "pooling_type": "cls",
   "position_embedding_type": "absolute",
   "rotary_emb_base": 10000.0,
@@ -82,7 +82,7 @@
   "sliding_window": 128,
   "transformers_version": "4.44.1",
   "type_vocab_size": 2,
-  "unpad_embeddings": true,
+  "unpad_embeddings": false,
   "use_cache": true,
   "use_fa2": true,
   "use_sdpa_attn_mask": false,
modeling_flexbert.py CHANGED
@@ -1724,32 +1724,12 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         if attention_mask is not None:
             attention_mask = attention_mask[:, -1:]
 
-        # Handle unpadding for the last token if needed
-        if self.unpad_embeddings:
-            batch_size, seq_len = input_ids.shape[:2]
-            if attention_mask is None:
-                # create all ones, except for padding (TODO?)
-                attention_mask = torch.ones_like(input_ids)
-            input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
-                input_ids, attention_mask, None, None
-            )
-            return {
-                "input_ids": input_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache", True),
-                "attention_mask": None,  # FA handles this
-                "indices": indices,
-                "cu_seqlens": cu_seqlens,
-                "max_seqlen": max_seqlen,
-                "position_ids": position_ids,
-            }
-        else:
-            return {
-                "input_ids": input_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache", True),
-                "attention_mask": attention_mask,
-            }
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache", True),
+            "attention_mask": attention_mask,
+        }
 
     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
         """Returns the number of parameters in the model.