Update modeling_diff_llama.py
Browse files- modeling_diff_llama.py +21 -0
modeling_diff_llama.py
CHANGED
|
@@ -506,5 +506,26 @@ class DiffLLaMAForCausalLM(PreTrainedModel):
|
|
| 506 |
hidden_states=outputs.hidden_states,
|
| 507 |
attentions=outputs.attentions,
|
| 508 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
|
| 510 |
|
|
|
|
| 506 |
hidden_states=outputs.hidden_states,
|
| 507 |
attentions=outputs.attentions,
|
| 508 |
)
|
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the per-step input dict consumed by `generate()`.

    Args:
        input_ids: Token ids accumulated so far; trimmed to the last token
            once a cache is present.
        past_key_values: Cached key/value states from previous steps, or None
            on the first step.
        attention_mask: Optional attention mask, forwarded unchanged.
        inputs_embeds: Optional precomputed embeddings; used only on the
            first step (when no cache exists yet).
        **kwargs: May carry `use_cache` and `cache_position`, which are
            forwarded (None when absent).

    Returns:
        dict with either `input_ids` or `inputs_embeds`, plus
        `past_key_values`, `use_cache`, `attention_mask`, `cache_position`.
    """
    # Once a cache is populated, only the newest token needs to be fed in.
    if past_key_values:
        input_ids = input_ids[:, -1:]

    # `inputs_embeds` win only on the very first step (no cache yet);
    # every later step goes back to token ids.
    first_step_with_embeds = inputs_embeds is not None and past_key_values is None
    if first_step_with_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    # Forward the remaining generation state; missing kwargs become None.
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    model_inputs["cache_position"] = kwargs.get("cache_position")
    return model_inputs
|
| 530 |
|
| 531 |
|