Fix tuple change from reapply

#5
Files changed (1) hide show
  1. modeling_mpt.py +6 -6
modeling_mpt.py CHANGED
@@ -248,7 +248,7 @@ class MPTModel(MPTPreTrainedModel):
248
 
249
  return custom_forward
250
 
251
- (x, past_key_value) = torch.utils.checkpoint.checkpoint(
252
  create_custom_forward(block),
253
  x,
254
  past_key_value,
@@ -260,11 +260,11 @@ class MPTModel(MPTPreTrainedModel):
260
  past_key_values[b_idx] = past_key_value
261
  else:
262
  (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
263
- if presents is not None:
264
- presents += (present,)
265
- if output_attentions:
266
- assert all_self_attns is not None
267
- all_self_attns = all_self_attns + (attn_weights,)
268
 
269
 
270
  x = self.norm_f(x)
 
248
 
249
  return custom_forward
250
 
251
+ (x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
252
  create_custom_forward(block),
253
  x,
254
  past_key_value,
 
260
  past_key_values[b_idx] = past_key_value
261
  else:
262
  (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
263
+ if presents is not None:
264
+ presents += (present,)
265
+ if output_attentions:
266
+ assert all_self_attns is not None
267
+ all_self_attns = all_self_attns + (attn_weights,)
268
 
269
 
270
  x = self.norm_f(x)