oweller2 committed 1f61dbc (1 parent: 3cd62e1): update

modeling_flexbert.py CHANGED (+19 -22)
@@ -1708,28 +1708,25 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
             attentions=None,
         )
 
-    def prepare_inputs_for_generation(
-        […18 lines of the removed implementation were lost in extraction…]
-        input_ids = torch.cat([input_ids, dummy_token], dim=1)
-
-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.Tensor,
+        past_key_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
+        # only last token for inputs if past is defined
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache", True),
+            "attention_mask": attention_mask,
+        }
 
     def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
         """Returns the number of parameters in the model.
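What the change does: judging from the surviving removed lines (a dummy_token concatenated onto input_ids, and a return carrying only input_ids and attention_mask), the old version appears to have followed the masked-LM-style pattern of appending a dummy token before each step. The committed version replaces it with a causal, KV-cache-aware preparation: once past_key_values exists, only the newest token is fed to the model. The sketch below is not FlexBertForCausalLM verbatim; it copies the committed method as a free function (dropping `self`, which the body never uses) with made-up demo tensors, to show the trimming behavior generate() would see.

from typing import Optional

import torch


def prepare_inputs_for_generation(
    input_ids: torch.Tensor,
    past_key_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> dict:
    # Mirror of the committed method: once a KV cache exists,
    # only the last token (and its mask entry) is kept.
    if past_key_values is not None:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if attention_mask is not None:
            attention_mask = attention_mask[:, -1:]
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache", True),
        "attention_mask": attention_mask,
    }


ids = torch.tensor([[101, 2023, 2003]])  # batch of 1, three tokens
mask = torch.ones_like(ids)

# First step: no cache yet, the full sequence is passed through.
print(prepare_inputs_for_generation(ids, None, mask)["input_ids"].shape)  # torch.Size([1, 3])

# Later steps: a (dummy) cache is present, so only the last token is passed.
dummy_cache = torch.zeros(1)
print(prepare_inputs_for_generation(ids, dummy_cache, mask)["input_ids"].shape)  # torch.Size([1, 1])

One design note: trimming attention_mask to [:, -1:] mirrors the trimming of input_ids; whether the mask should instead keep its full length (covering the cached positions) depends on how FlexBERT's attention implementation consumes it, which this hunk alone does not show.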