zRzRzRzRzRzRzR committed on
Commit
37fe000
1 Parent(s): 37f2196

Support transformers>=4.37.2 for fine-tuning

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +5 -4
modeling_chatglm.py CHANGED
@@ -634,7 +634,8 @@ class GLMTransformer(torch.nn.Module):
634
  attention_mask,
635
  rotary_pos_emb,
636
  kv_caches[index],
637
- use_cache
 
638
  )
639
  else:
640
  layer_ret = layer(
@@ -697,9 +698,9 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
697
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
698
  return position_ids
699
 
700
- def _set_gradient_checkpointing(self, module, value=False):
701
- if isinstance(module, GLMTransformer):
702
- module.gradient_checkpointing = value
703
 
704
 
705
  class Embedding(torch.nn.Module):
 
634
  attention_mask,
635
  rotary_pos_emb,
636
  kv_caches[index],
637
+ use_cache,
638
+ use_reentrant=False
639
  )
640
  else:
641
  layer_ret = layer(
 
698
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
699
  return position_ids
700
 
701
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
702
+ if not self.supports_gradient_checkpointing:
703
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
704
 
705
 
706
  class Embedding(torch.nn.Module):