Alex Birch committed on
Commit 1e53ac9
1 Parent(s): 9f0a20b

apply gradient checkpointing to Attention blocks

Files changed (1)
  1. modeling_mpt.py +2 -2
modeling_mpt.py CHANGED
@@ -12,7 +12,7 @@ from torch.utils.checkpoint import checkpoint
 from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.utils import logging
-from .attention import attn_bias_shape, build_attn_bias, PastKeyValue
+from .attention import attn_bias_shape, build_attn_bias, PastKeyValue, MultiheadAttention, MultiQueryAttention
 from .blocks import MPTBlock, MPTBlockOutput
 from .norm import NORM_CLASS_REGISTRY
 from .configuration_mpt import MPTConfig
@@ -41,7 +41,7 @@ class MPTPreTrainedModel(PreTrainedModel):
     _no_split_modules = ['MPTBlock']
     supports_gradient_checkpointing = True
     def _set_gradient_checkpointing(self, module: nn.Module, value=False) -> None:
-        if isinstance(module, MPTModel):
+        if isinstance(module, MPTModel) or isinstance(module, MultiheadAttention) or isinstance(module, MultiQueryAttention):
             module.gradient_checkpointing = value
 
 class MPTModel(MPTPreTrainedModel):
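
For context, a minimal usage sketch (not part of the commit) of how this hook is exercised. It assumes a transformers version whose PreTrainedModel.gradient_checkpointing_enable() walks the module tree and calls the old-style _set_gradient_checkpointing(module, value=True) hook shown above; the checkpoint id below is a placeholder, not a repository named in this commit.

# Sketch: enabling gradient checkpointing on an MPT model that ships this remote code.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/mpt-checkpoint-with-this-remote-code",  # placeholder repo id (assumption)
    trust_remote_code=True,                          # load the repo's modeling_mpt.py
)

# With this change, the flag is set not only on MPTModel but also on any
# MultiheadAttention / MultiQueryAttention modules, so attention activations
# are recomputed in the backward pass instead of being kept in memory.
model.gradient_checkpointing_enable()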