The gradient_checkpointing_enable method does not actually enable gradient checkpointing

#56
by Qleon - opened

modeling_chatglm.py

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

Delete this function and replace it with the _set_gradient_checkpointing method from the chatglm2 modeling code. That is enough to fix it:

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value
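
To illustrate why the hook above is enough, here is a minimal sketch in plain Python with no torch or transformers dependency. It assumes the base-class behaviour that, as far as I understand, older transformers releases had: gradient_checkpointing_enable walks every submodule with apply and calls _set_gradient_checkpointing on each one. The Module and ToyChatGLMModel classes are hypothetical stand-ins written for this example, not the real ChatGLM code.

```python
from functools import partial


class Module:
    """Toy stand-in for torch.nn.Module; it only supports child traversal."""

    def __init__(self):
        self._children = []

    def add(self, child):
        self._children.append(child)
        return child

    def apply(self, fn):
        # Visit children first, then self, like torch.nn.Module.apply.
        for child in self._children:
            child.apply(fn)
        fn(self)
        return self


class GLMTransformer(Module):
    """Hypothetical stand-in for the real GLMTransformer block."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class ToyChatGLMModel(Module):
    supports_gradient_checkpointing = True

    def __init__(self):
        super().__init__()
        self.encoder = self.add(GLMTransformer())

    # The hook from the post: flip the flag on every GLMTransformer.
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value

    # Assumed base-class behaviour (older transformers style), reproduced
    # here only so the sketch runs on its own.
    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(
                f"{self.__class__.__name__} does not support gradient checkpointing."
            )
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model = ToyChatGLMModel()
model.gradient_checkpointing_enable()
print(model.encoder.gradient_checkpointing)
```

On a real model, a comparable check would be to read the gradient_checkpointing attribute of the GLMTransformer submodule after calling gradient_checkpointing_enable(). The attribute path to that submodule depends on the model code, so I am not spelling it out here.
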
zRzRzRzRzRzRzR changed discussion status to closed
