Fix tuple change from reapply
modeling_mpt.py  (+6 -6)
@@ -248,7 +248,7 @@ class MPTModel(MPTPreTrainedModel):
 
                     return custom_forward
 
-                (x,
+                (x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     x,
                     past_key_value,
@@ -260,11 +260,11 @@ class MPTModel(MPTPreTrainedModel):
                     past_key_values[b_idx] = past_key_value
             else:
                 (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
-
-
-
-
-
+            if presents is not None:
+                presents += (present,)
+            if output_attentions:
+                assert all_self_attns is not None
+                all_self_attns = all_self_attns + (attn_weights,)
 
 
         x = self.norm_f(x)
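Below is a minimal, self-contained sketch (not the actual MPT block or its full forward loop) of the pattern the diff settles on: wrapping the block for torch.utils.checkpoint.checkpoint and unpacking the returned (x, attn_weights, present) tuple the same way as the direct block call, then doing the same presents/all_self_attns bookkeeping in both branches. SimpleBlock, its sizes, and the use_reentrant=False choice are assumptions made for illustration.

import torch
import torch.nn as nn
import torch.utils.checkpoint


class SimpleBlock(nn.Module):
    """Illustrative stand-in for an MPT transformer block that returns a 3-tuple."""

    def __init__(self, d_model):
        super().__init__()
        self.ff = nn.Linear(d_model, d_model)

    def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None):
        # A real block would also compute attention weights and a KV-cache entry;
        # here they are placeholders so the tuple shape matches the diff.
        return (torch.relu(self.ff(x)), None, past_key_value)


def create_custom_forward(module):
    # checkpoint() forwards only positional args, so the block call is wrapped
    # in a plain function of *inputs.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward


block = SimpleBlock(8)
x = torch.randn(2, 4, 8, requires_grad=True)
past_key_value, attn_bias, attention_mask = None, None, None
output_attentions = False
presents = ()
all_self_attns = () if output_attentions else None

# Fixed call site: unpack the 3-tuple exactly as in the non-checkpointed branch.
(x, attn_weights, present) = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    x,
    past_key_value,
    attn_bias,
    attention_mask,
    use_reentrant=False,  # non-reentrant variant; assumes a reasonably recent PyTorch
)

# Tuple bookkeeping mirroring the lines added in the diff.
if presents is not None:
    presents += (present,)
if output_attentions:
    assert all_self_attns is not None
    all_self_attns = all_self_attns + (attn_weights,)

x.sum().backward()  # gradients flow through the recomputed block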