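Make KangarooForCausalLM compatible with transformers>=4.42.0: `inputs_embeds` are now used on the first generation step when `past_key_values` is either `None` or an empty `Cache` (sequence length 0).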
Browse files - modeling_kangaroo.py (+9 additions, -1 deletion)
modeling_kangaroo.py
CHANGED
@@ -1346,7 +1346,15 @@ class KangarooForCausalLM(LlamaPreTrainedModel):
|
|
1346 |
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1347 |
|
1348 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1350 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
1351 |
else:
|
1352 |
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
|
|
1346 |
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1347 |
|
1348 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1349 |
+
set_inputs_embeds = False
|
1350 |
+
if inputs_embeds is not None:
|
1351 |
+
if isinstance(past_key_values, Cache):
|
1352 |
+
if past_key_values.get_seq_length() == 0:
|
1353 |
+
set_inputs_embeds = True
|
1354 |
+
else:
|
1355 |
+
if past_key_values is None:
|
1356 |
+
set_inputs_embeds = True
|
1357 |
+
if set_inputs_embeds:
|
1358 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
1359 |
else:
|
1360 |
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|