Add gradient checkpointing for the final_layernorm module.
Without this, when tuning with LoRA + gradient checkpointing, the LoRA weights of the last transformer layer (i.e., layer 27) are never updated!

For example, if we use the callback below to log per-layer statistics of the LoRA weights, TensorBoard shows no weight change at all for the last layer.
```
from transformers import Trainer
from transformers.integrations import TensorBoardCallback


class ParamsTensorBoardCallback(TensorBoardCallback):
    """Log summary statistics of selected parameters to TensorBoard."""

    def __init__(self, tb_writer=None, params=None, process_name=lambda x: x):
        super().__init__(tb_writer)
        self.params = params
        self._process_name = process_name

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            dict_ = {}
            model = kwargs["model"]
            for name in self.params:
                param = model.get_parameter(name).flatten()
                name_p = self._process_name(name)
                dict_.update({
                    f"{name_p}_mean": param.mean().item(),
                    f"{name_p}_max": param.max().item(),
                    f"{name_p}_q75": param.quantile(0.75).item(),
                    f"{name_p}_q25": param.quantile(0.25).item(),
                    f"{name_p}_min": param.min().item(),
                    f"{name_p}_median": param.median().item(),
                    f"{name_p}_std": param.std().item(),
                })
            self.on_log(args, state, control, logs=dict_, **kwargs)


def get_params_for_logging(model):
    # Only trainable parameters, i.e. the LoRA weights, are logged.
    return [name for name, param in model.named_parameters() if param.requires_grad]


ls_params = get_params_for_logging(model)
# `utils` / `param_name_trimmer_name` are project-specific; the helper just
# shortens parameter names into nicer TensorBoard tags.
tb_cb = ParamsTensorBoardCallback(
    None, ls_params, process_name=getattr(utils, param_name_trimmer_name)()
)
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=args,
    data_collator=data_collator,
    callbacks=[tb_cb],
)
```
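For context, here is a minimal sketch (not part of this PR) of the LoRA + gradient-checkpointing setup under which the issue appears; the checkpoint path, LoRA hyperparameters, and `target_modules` below are illustrative placeholders:

```
import torch
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

# Placeholder checkpoint path -- adjust to your own setup.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# Enable gradient checkpointing on the base model before wrapping it with LoRA.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()  # so checkpointed blocks receive gradients

lora_config = LoraConfig(
    r=8,                                # placeholder rank
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query_key_value"],  # placeholder target module name
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```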
- `modeling_chatglm.py` (+4, -1)

```
@@ -1012,7 +1012,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)
 
         # Final layer norm.
-        hidden_states = self.final_layernorm(hidden_states)
+        if self.gradient_checkpointing and self.training:
+            hidden_states = torch.utils.checkpoint.checkpoint(self.final_layernorm, hidden_states)
+        else:
+            hidden_states = self.final_layernorm(hidden_states)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
```
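After applying the patch, a quick way to confirm that gradients now reach the last layer's LoRA weights (without waiting for TensorBoard) is a single forward/backward pass. This is a hedged sketch: it assumes the PEFT-wrapped model and `trainer` from above, that the data collator supplies `labels`, and ChatGLM-6B's `layers.27` naming for the last of its 28 layers:

```
# Minimal gradient check: after one backward pass, every trainable LoRA weight
# of the last transformer layer should have a non-zero gradient. The substring
# "layers.27" assumes ChatGLM-6B's 28-layer naming scheme; adjust if needed.
batch = next(iter(trainer.get_train_dataloader()))
batch = {k: v.to(model.device) for k, v in batch.items()}

model.train()
loss = model(**batch).loss
loss.backward()

for name, param in model.named_parameters():
    if param.requires_grad and "layers.27" in name:
        grad_norm = 0.0 if param.grad is None else param.grad.norm().item()
        print(f"{name}: grad_norm={grad_norm:.6f}")  # stays 0.0 without this patch

model.zero_grad()
```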