Update modeling_gpt_refact.py
Browse files- modeling_gpt_refact.py +2 -2
modeling_gpt_refact.py
CHANGED
@@ -508,9 +508,9 @@ class GPTRefactForCausalLM(GPTRefactPreTrainedModel):
|
|
508 |
import transformers
|
509 |
from packaging import version
|
510 |
|
511 |
-
def _set_gradient_checkpointing(module,
|
512 |
if isinstance(module, GPTRefactModel):
|
513 |
-
module.gradient_checkpointing =
|
514 |
|
515 |
v = version.parse(transformers.__version__)
|
516 |
if v.major <= 4 and v.minor < 35:
|
|
|
508 |
import transformers
|
509 |
from packaging import version
|
510 |
|
511 |
+
def _set_gradient_checkpointing(module, value=False):
|
512 |
if isinstance(module, GPTRefactModel):
|
513 |
+
module.gradient_checkpointing = value
|
514 |
|
515 |
v = version.parse(transformers.__version__)
|
516 |
if v.major <= 4 and v.minor < 35:
|