Update recastmlp_llama/modeling_recastmlp_llama.py
recastmlp_llama/modeling_recastmlp_llama.py (changed)
@@ -281,21 +281,22 @@ class RECASTMLP_llamaModel(PreTrainedModel):
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-
+        # Set up cache position if not provided
+        if cache_position is None:
+            past_seen_tokens = 0 if past_key_values is None else (
+                past_key_values.get_seq_length() if isinstance(past_key_values, Cache)
+                else past_key_values[0][0].size(-2) if past_key_values
+                else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device
+            )
         # Create position embeddings to be shared across the decoder layers
+        # Set up position IDs if not provided
         if position_ids is None:
-            past_seen_tokens = (
-                past_key_values.get_seq_length() if past_key_values is not None else 0
-            )
-            position_ids = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + inputs_embeds.shape[1],
-                device=inputs_embeds.device,
-            ).unsqueeze(0)
-
-            position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
-        hidden_states = inputs_embeds
-
+            position_ids = cache_position.unsqueeze(0)
         # Get updated causal mask
         causal_mask = self._update_causal_mask(
             attention_mask,
@@ -304,6 +305,9 @@ class RECASTMLP_llamaModel(PreTrainedModel):
             past_key_values,
             output_attentions,
         )
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
 
         # Initialize outputs
         all_hidden_states = () if output_hidden_states else None