appledora committed on
Commit
ab674e7
1 Parent(s): 7a1d06b

Update modeling_recastmlp_llama.py

Files changed (1)
  1. modeling_recastmlp_llama.py +17 -13
modeling_recastmlp_llama.py CHANGED
@@ -281,21 +281,22 @@ class RECASTMLP_llamaModel(PreTrainedModel):
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-
+        # Set up cache position if not provided
+        if cache_position is None:
+            past_seen_tokens = 0 if past_key_values is None else (
+                past_key_values.get_seq_length() if isinstance(past_key_values, Cache)
+                else past_key_values[0][0].size(-2) if past_key_values
+                else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device
+            )
         # Create position embeddings to be shared across the decoder layers
+        # Set up position IDs if not provided
         if position_ids is None:
-            past_seen_tokens = (
-                past_key_values.get_seq_length() if past_key_values is not None else 0
-            )
-            position_ids = torch.arange(
-                past_seen_tokens,
-                past_seen_tokens + inputs_embeds.shape[1],
-                device=inputs_embeds.device,
-            ).unsqueeze(0)
-
-        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
-        hidden_states = inputs_embeds
-
+            position_ids = cache_position.unsqueeze(0)
         # Get updated causal mask
         causal_mask = self._update_causal_mask(
             attention_mask,
@@ -304,6 +305,9 @@ class RECASTMLP_llamaModel(PreTrainedModel):
             past_key_values,
             output_attentions,
         )
+        hidden_states = inputs_embeds
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
 
         # Initialize outputs
         all_hidden_states = () if output_hidden_states else None
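
Not part of the commit: a minimal, self-contained sketch of the position bookkeeping this diff introduces, showing how cache_position and position_ids continue from the number of tokens already in the KV cache. make_positions is a hypothetical helper written for illustration only; it uses hasattr(past_key_values, "get_seq_length") instead of the isinstance(past_key_values, Cache) check so the sketch needs only torch.

# Illustrative sketch of the fallback added in this commit (not the model code).
import torch

def make_positions(inputs_embeds, past_key_values=None):
    # A Cache-like object exposes get_seq_length(); a legacy tuple cache stores
    # key tensors of shape (batch, heads, seq, head_dim), so seq is size(-2).
    if past_key_values is None:
        past_seen_tokens = 0
    elif hasattr(past_key_values, "get_seq_length"):
        past_seen_tokens = past_key_values.get_seq_length()
    else:
        past_seen_tokens = past_key_values[0][0].size(-2) if past_key_values else 0

    # New positions continue from the cached length, one per new input token.
    cache_position = torch.arange(
        past_seen_tokens,
        past_seen_tokens + inputs_embeds.shape[1],
        device=inputs_embeds.device,
    )
    position_ids = cache_position.unsqueeze(0)  # add batch dim, as in the diff
    return cache_position, position_ids

# Example: 5 tokens already cached (legacy tuple cache), decoding 1 new token.
legacy_cache = [(torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64))]
embeds = torch.zeros(1, 1, 4096)
cache_position, position_ids = make_positions(embeds, legacy_cache)
print(cache_position.tolist(), position_ids.tolist())  # [5] [[5]]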