feat: support gradient checkpointing
Browse files — modeling_bert.py (+16 −0)
modeling_bert.py
CHANGED
@@ -154,6 +154,17 @@ class BertEncoder(nn.Module):
|
|
154 |
self.layers = nn.ModuleList(
|
155 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
156 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
159 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
@@ -298,6 +309,11 @@ class BertPreTrainedModel(PreTrainedModel):
|
|
298 |
"""
|
299 |
config_class = JinaBertConfig
|
300 |
base_model_prefix = "bert"
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
|
303 |
class BertModel(BertPreTrainedModel):
|
|
|
154 |
self.layers = nn.ModuleList(
|
155 |
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
156 |
)
|
157 |
+
self._grad_checkpointing = False
|
158 |
+
|
159 |
+
@property
def gradient_checkpointing(self):
    """Whether activation (gradient) checkpointing is currently enabled."""
    return self._grad_checkpointing

@gradient_checkpointing.setter
def gradient_checkpointing(self, value):
    """Enable/disable checkpointing, propagating the flag to every layer.

    Besides recording the flag on the encoder itself, each block's
    ``mixer.checkpointing`` attribute is updated so the layers actually
    perform (or skip) recomputation during the backward pass.
    """
    self._grad_checkpointing = value
    for blk in self.layers:
        blk.mixer.checkpointing = value
|
168 |
|
169 |
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
170 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
|
|
309 |
"""
|
310 |
config_class = JinaBertConfig
|
311 |
base_model_prefix = "bert"
|
312 |
+
# Advertises checkpointing support to the base PreTrainedModel machinery.
supports_gradient_checkpointing = True

def _set_gradient_checkpointing(self, module, value=False):
    """Toggle gradient checkpointing on ``module`` if it is a BertEncoder.

    Modules of any other type are left untouched; the encoder's own
    ``gradient_checkpointing`` property handles fan-out to its layers.
    """
    if not isinstance(module, BertEncoder):
        return
    module.gradient_checkpointing = value
|
317 |
|
318 |
|
319 |
class BertModel(BertPreTrainedModel):
|