oweller2
committed on
Commit
•
46797c8
1
Parent(s):
81b671b
no pad at inference
Browse files- modeling_flexbert.py +1 -3
modeling_flexbert.py
CHANGED
@@ -1721,14 +1721,12 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1721 |
# only last token for inputs if past is defined
|
1722 |
if past_key_values is not None:
|
1723 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
1724 |
-
if attention_mask is not None:
|
1725 |
-
attention_mask = attention_mask[:, -1:]
|
1726 |
|
1727 |
return {
|
1728 |
"input_ids": input_ids,
|
1729 |
"past_key_values": past_key_values,
|
1730 |
"use_cache": kwargs.get("use_cache", True),
|
1731 |
-
"attention_mask":
|
1732 |
}
|
1733 |
|
1734 |
def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
|
|
|
1721 |
# only last token for inputs if past is defined
|
1722 |
if past_key_values is not None:
|
1723 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
|
|
|
|
1724 |
|
1725 |
return {
|
1726 |
"input_ids": input_ids,
|
1727 |
"past_key_values": past_key_values,
|
1728 |
"use_cache": kwargs.get("use_cache", True),
|
1729 |
+
"attention_mask": None,
|
1730 |
}
|
1731 |
|
1732 |
def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
|