#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import Dataset

import transformers
from transformers import (
    PreTrainedModel,
    TrainingArguments,
    DataCollator,
    PreTrainedTokenizerBase,
    EvalPrediction,
    TrainerCallback,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Names of the files used for checkpointing.
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"


class Trainer(transformers.Trainer):
    """
    Main change: accept a ``compute_loss`` callable at construction time,
    so the loss computation can be customized without subclassing again.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
        compute_loss: Optional[Callable] = None,
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.loss_func = compute_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Override how the loss is computed by Trainer. By default, all models
        return the loss in the first element of their output; here we delegate
        to the user-supplied callable instead. Subclass and override for
        custom behavior.
        """
        if self.loss_func is None:
            # No custom loss was supplied; fall back to the default behavior.
            return super().compute_loss(model, inputs, return_outputs=return_outputs)
        return self.loss_func(model, inputs, self.args, return_outputs)
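

# A minimal sketch of a compatible ``compute_loss`` callable. The signature
# (model, inputs, args, return_outputs) matches how ``Trainer.compute_loss``
# above invokes ``self.loss_func``; the cross-entropy body below is an
# illustrative assumption for causal LM training, not part of the original file.
def example_compute_loss(model, inputs, args, return_outputs=False):
    labels = inputs.pop("labels")
    outputs = model(**inputs)
    logits = outputs.logits
    # Standard next-token shift for causal language modeling.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = nn.CrossEntropyLoss()(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    )
    return (loss, outputs) if return_outputs else loss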


class LoRATrainer(Trainer):
    """
    Change the checkpoint-saving logic so that only the LoRA adapter
    (weights and config) is saved instead of the full model.
    """

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save the LoRA weights and adapter configuration. For a PEFT-wrapped
        # model, ``save_pretrained`` writes only the adapter files.
        self.model.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Good practice: save the training arguments together with the trained model.
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
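

# A hedged usage sketch (not part of the original file): wrap a base model
# with a PEFT LoRA adapter and train with LoRATrainer so that checkpoints
# contain only the adapter. The model name, the LoRA hyperparameters, and the
# ``train_dataset`` argument are illustrative assumptions.
def example_lora_training(train_dataset):
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained("gpt2")
    peft_model = get_peft_model(
        base, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM")
    )
    trainer = LoRATrainer(
        model=peft_model,
        args=TrainingArguments(output_dir="./outputs", per_device_train_batch_size=2),
        train_dataset=train_dataset,
        tokenizer=AutoTokenizer.from_pretrained("gpt2"),
        compute_loss=example_compute_loss,
    )
    trainer.train()
    return trainer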


if __name__ == '__main__':
    pass