Update modeling_jais.py
Browse files- modeling_jais.py +0 -5
modeling_jais.py
CHANGED
@@ -535,11 +535,6 @@ class JAISPreTrainedModel(PreTrainedModel):
|
|
535 |
stddev = self.config.initializer_range * mup_init_scale / math.sqrt(2 * self.config.n_layer)
|
536 |
p.data.normal_(mean=0.0, std=stddev)
|
537 |
|
538 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
539 |
-
if isinstance(module, JAISModel):
|
540 |
-
module.gradient_checkpointing = value
|
541 |
-
|
542 |
-
|
543 |
JAIS_START_DOCSTRING = r"""
|
544 |
|
545 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
|
535 |
stddev = self.config.initializer_range * mup_init_scale / math.sqrt(2 * self.config.n_layer)
|
536 |
p.data.normal_(mean=0.0, std=stddev)
|
537 |
|
|
|
|
|
|
|
|
|
|
|
538 |
JAIS_START_DOCSTRING = r"""
|
539 |
|
540 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|