oweller2 committed
Commit b9219f0 • 1 Parent: 322b01b

try unpadding in inference

modeling_flexbert.py CHANGED (+26 -6)
@@ -1724,12 +1724,32 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         if attention_mask is not None:
             attention_mask = attention_mask[:, -1:]
 
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache", True),
-            "attention_mask": attention_mask,
-        }
+        # Handle unpadding for the last token if needed
+        if self.unpad_embeddings:
+            batch_size, seq_len = input_ids.shape[:2]
+            if attention_mask is None:
+                # create all ones, except for padding (TODO?)
+                attention_mask = torch.ones_like(input_ids)
+            input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = self.unpad_inputs(
+                input_ids, attention_mask, None, None
+            )
+            return {
+                "input_ids": input_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache", True),
+                "attention_mask": None,  # FA handles this
+                "indices": indices,
+                "cu_seqlens": cu_seqlens,
+                "max_seqlen": max_seqlen,
+                "position_ids": position_ids,
+            }
+        else:
+            return {
+                "input_ids": input_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache", True),
+                "attention_mask": attention_mask,
+            }
 
     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
         """Returns the number of parameters in the model.
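For context on what the unpadded branch produces: unpad_inputs itself is not shown in this diff, but the call site matches the standard Flash Attention variable-length ("varlen") convention, where a padded [batch, seq_len] batch is flattened into one packed token stream and sequence boundaries travel in cu_seqlens instead of an attention mask. Below is a minimal sketch of that convention; the name unpad_inputs_sketch and its body are illustrative assumptions, not the repository's implementation. Only the six-value return order mirrors the call site above.

import torch
import torch.nn.functional as F

def unpad_inputs_sketch(input_ids, attention_mask):
    # Hypothetical stand-in for self.unpad_inputs, assuming the standard
    # Flash Attention varlen layout: keep only positions where
    # attention_mask == 1 and pack them into a single 1-D token stream.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)      # tokens per sequence
    indices = torch.nonzero(attention_mask.flatten()).flatten()  # kept flat positions
    cu_seqlens = F.pad(
        torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0)
    )                                                            # [0, l0, l0+l1, ...]
    max_seqlen = int(seqlens.max())
    packed_ids = input_ids.flatten()[indices]                    # packed tokens, no padding
    # Positions restart at 0 at every sequence boundary in the packed stream
    position_ids = torch.cat(
        [torch.arange(n, device=input_ids.device) for n in seqlens.tolist()]
    )
    return packed_ids, indices, cu_seqlens, max_seqlen, position_ids, None  # last slot: labels

Under this layout, a mask batch of [[1,1,1,0],[1,1,0,0]] packs to five tokens with cu_seqlens == [0, 3, 5], which is why the diff returns attention_mask as None in the unpadded branch: the varlen kernel reads sequence boundaries from cu_seqlens directly.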