RL/DPO (#935)
* ipo-dpo trainer
* fix missing abstract method
* chatml template, grad checkpointing kwargs support
* fix steps calc for RL and add dataloader kwargs
* wip to fix dpo and start ppo
* more fixes
* refactor to generalize map fn
* fix dataset loop and handle argilla pref dataset
* set training args
* load reference model on separate gpu if more than one device
* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo
* fixes for rl training
* support for ipo from yaml
* set dpo training args from the config, add tests
* chore: lint
* set sequence_len for model in test
* add RLHF docs
- docs/rlhf.md +35 -0
- requirements.txt +2 -0
- src/axolotl/cli/__init__.py +90 -0
- src/axolotl/cli/train.py +5 -1
- src/axolotl/core/trainer_builder.py +103 -0
- src/axolotl/core/trainers/__init__.py +0 -0
- src/axolotl/core/trainers/trl.py +66 -0
- src/axolotl/train.py +7 -1
- src/axolotl/utils/models.py +14 -2
- src/axolotl/utils/trainer.py +7 -2
- tests/core/test_trainer_builder.py +59 -0
docs/rlhf.md
ADDED
@@ -0,0 +1,35 @@
# RLHF (Beta)

### Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized using human feedback data. Methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)


### RLHF using Axolotl

[!IMPORTANT]
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

The RL training methods are implemented in trl and wrapped via axolotl. Below are examples of how you can use various preference datasets to train models that use ChatML.

#### DPO
```yaml
rl: true
datasets:
  - path: Intel/orca_dpo_pairs
    split: train
    type: intel_apply_chatml
  - path: argilla/ultrafeedback-binarized-preferences
    split: train
    type: argilla_apply_chatml
```

#### IPO
```yaml
rl: ipo
```
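For background on why the trainer below needs both a policy and a frozen reference model (and a `dpo_beta` value): the objective that trl's `DPOTrainer` optimizes for a preference pair $(x, y_w, y_l)$ is the standard DPO loss,

$$
\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right].
$$

IPO keeps the same reference-normalized log-ratio margin but, roughly speaking, replaces the logistic loss with a squared loss regressing that margin toward $1/(2\beta)$; that is what the `loss_type: "ipo"` branch in the trainer builder selects.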
requirements.txt
CHANGED
@@ -37,3 +37,5 @@ tensorboard
 s3fs
 gcsfs
 # adlfs
+
+trl @ git+https://github.com/huggingface/trl.git@main
src/axolotl/cli/__init__.py
CHANGED
@@ -2,6 +2,7 @@
 
 import importlib
 import logging
+import math
 import os
 import random
 import sys
@@ -16,6 +17,7 @@ import yaml
 # add src to the pythonpath so we don't need to pip install this
 from accelerate.commands.config import config_args
 from art import text2art
+from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -325,6 +327,94 @@ def load_datasets(
     )
 
 
+def load_rl_datasets(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
+) -> TrainDatasetMeta:
+    train_datasets: List[Any] = []
+    for i, ds_cfg in enumerate(cfg.datasets):
+        train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
+    # eval_dataset = load_dataset(
+    #     cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
+    # )
+    eval_dataset = None
+
+    def argilla_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
+        return sample
+
+    def intel_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    def apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    def ultra_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
+        return sample
+
+    for i, data_set in enumerate(train_datasets):
+        _type = cfg.datasets[i]["type"]
+        ds_type_fn = locals()[_type]
+        train_datasets[i] = data_set.map(ds_type_fn)
+    train_dataset = concatenate_datasets(train_datasets)
+
+    # eval_dataset = eval_dataset.map(intel_apply_chatml)
+
+    total_num_steps = int(
+        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+    )
+
+    return TrainDatasetMeta(
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        total_num_steps=total_num_steps,
+    )
+
+
 def check_accelerate_default_config():
     if Path(config_args.default_yaml_config_file).exists():
         LOG.warning(
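To make the mapping functions above concrete, here is a standalone sketch of the transformation `intel_apply_chatml` performs on a single record. The field values are invented for illustration, and the record shape (`system`, `question`, `chosen`, `rejected`) is taken from the code above rather than verified against the dataset itself.

```python
# Standalone sketch of the intel_apply_chatml transform; the sample values are
# illustrative only and not taken from the Intel/orca_dpo_pairs dataset.
sample = {
    "system": "You are a helpful assistant.",
    "question": "What is 2 + 2?",
    "chosen": "2 + 2 equals 4.",
    "rejected": "2 + 2 equals 5.",
}

# Build the ChatML prompt and terminate both completions with <|im_end|>,
# mirroring what the map function above does.
sample["prompt"] = (
    f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
    f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"

# DPOTrainer consumes exactly these three columns: prompt, chosen, rejected.
print(sample["prompt"])
print(sample["chosen"])
print(sample["rejected"])
```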
src/axolotl/cli/train.py
CHANGED
@@ -12,6 +12,7 @@ from axolotl.cli import (
     check_user_token,
     load_cfg,
     load_datasets,
+    load_rl_datasets,
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
@@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
-    dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cfg.rl:
+        dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
src/axolotl/core/trainer_builder.py
CHANGED
@@ -20,6 +20,7 @@ from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_utils import seed_worker
+from trl import DPOTrainer
 
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
@@ -420,12 +421,21 @@ class TrainerBuilderBase(abc.ABC):
 
     _train_dataset = None
     _eval_dataset = None
+    _model_ref = None
 
     def __init__(self, cfg, model, tokenizer):
        self.cfg = cfg
        self.model = model
        self.tokenizer = tokenizer
 
+    @property
+    def model_ref(self):
+        return self._model_ref
+
+    @model_ref.setter
+    def model_ref(self, model):
+        self._model_ref = model
+
     @property
     def train_dataset(self):
         return self._train_dataset
@@ -827,3 +837,96 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return_tensors="pt",
             **kwargs,
         )
+
+
+class HFDPOTrainerBuilder(TrainerBuilderBase):
+    """
+    Trainer factory class for DPO Trainer
+    """
+
+    def get_callbacks(self):
+        callbacks = []
+        return callbacks
+
+    def get_post_trainer_create_callbacks(self, trainer):
+        callbacks = []
+        return callbacks
+
+    def build_training_arguments(self, total_num_steps):
+        training_args_kwargs = {}
+        for arg in [
+            "adam_beta1",
+            "adam_beta2",
+            "adam_epsilon",
+            "dataloader_num_workers",
+            "dataloader_pin_memory",
+        ]:
+            if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
+                training_args_kwargs[arg] = getattr(self.cfg, arg)
+        training_args = TrainingArguments(
+            per_device_train_batch_size=self.cfg.micro_batch_size,
+            max_steps=total_num_steps,
+            remove_unused_columns=False,
+            gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
+            learning_rate=self.cfg.learning_rate,
+            evaluation_strategy="no",
+            # eval_steps=self.cfg.eval_steps,
+            save_strategy="steps",
+            save_steps=self.cfg.save_steps,
+            output_dir=self.cfg.output_dir,
+            warmup_steps=self.cfg.warmup_steps,
+            bf16=True,
+            gradient_checkpointing=self.cfg.gradient_checkpointing,
+            gradient_checkpointing_kwargs={"use_reentrant": False},
+            logging_first_step=True,
+            logging_steps=1,
+            optim=self.cfg.optimizer,
+            save_total_limit=self.cfg.save_total_limit or 5,
+            **training_args_kwargs,
+        )
+
+        return training_args
+
+    def build(self, total_num_steps):
+        training_args = self.build_training_arguments(total_num_steps)
+        dpo_trainer_kwargs = {}
+        if self.cfg.rl == "ipo":
+            dpo_trainer_kwargs["loss_type"] = "ipo"
+            if self.cfg.dpo_label_smoothing:
+                dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
+
+        dpo_trainer = DPOTrainer(
+            self.model,
+            self.model_ref,
+            args=training_args,
+            beta=self.cfg.dpo_beta or 0.1,
+            train_dataset=self.train_dataset,
+            # eval_dataset=self.eval_dataset,
+            eval_dataset=None,
+            tokenizer=self.tokenizer,
+            max_length=self.cfg.sequence_len,
+            max_target_length=None,
+            max_prompt_length=self.cfg.sequence_len,
+            generate_during_eval=True,
+            **dpo_trainer_kwargs,
+        )
+
+        return dpo_trainer
+
+
+class HFPPOTrainerBuilder(TrainerBuilderBase):
+    """
+    HF Factory class for PPO Trainer
+    """
+
+    def get_callbacks(self):
+        callbacks = []
+        return callbacks
+
+    def get_post_trainer_create_callbacks(self, trainer):
+        callbacks = []
+        return callbacks
+
+    def build(self, total_num_steps):
+        # build PPOConfig
+        pass
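For orientation, here is a rough sketch of how this builder gets wired up end to end, mirroring `setup_trainer` in `src/axolotl/utils/trainer.py` below and the unit test at the bottom of this PR. The base model, config values, and `train_dataset` are placeholders, not something this PR prescribes.

```python
# Sketch of the end-to-end wiring (assumes a preference dataset `train_dataset`
# with prompt/chosen/rejected columns already exists; model and config values
# are placeholders).
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

cfg = DictDefault(
    {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "model_type": "AutoModelForCausalLM",
        "tokenizer_type": "LlamaTokenizer",
        "rl": True,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "learning_rate": 5e-5,
        "sequence_len": 2048,
        "warmup_steps": 10,
        "save_steps": 100,
        "gradient_checkpointing": False,
        "optimizer": "adamw_torch",
        "output_dir": "./model-out",
    }
)
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer)                            # policy
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)  # frozen baseline

builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
builder.model_ref = model_ref
builder.train_dataset = train_dataset  # noqa: F821 - produced by load_rl_datasets
builder.eval_dataset = None

dpo_trainer = builder.build(total_num_steps=100)  # returns a trl DPOTrainer
dpo_trainer.train()
```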
src/axolotl/core/trainers/__init__.py
ADDED
(empty file)
src/axolotl/core/trainers/trl.py
ADDED
@@ -0,0 +1,66 @@
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import PPOTrainer


class TRLPPOTrainer(PPOTrainer):
    """
    wrapper for ppo trainer to handle customizations
    """

    def train(
        self,
        reward_pipe,
        resume_from_checkpoint=None,  # pylint: disable=unused-argument
    ):
        generation_kwargs = {
            "min_length": -1,
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.eos_token_id,
            "max_new_tokens": 32,
        }
        sent_kwargs = {
            "return_all_scores": True,
            "function_to_apply": "none",
            "batch_size": 16,
        }

        for epoch, batch in tqdm(  # pylint: disable=unused-variable
            enumerate(self.dataloader)
        ):
            query_tensors = batch["input_ids"]

            # generate model response
            response_tensors, ref_response_tensors = self.generate(
                query_tensors,
                return_prompt=False,
                generate_ref_response=True,
                **generation_kwargs
            )
            batch["response"] = self.tokenizer.batch_decode(response_tensors)
            batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)

            # Compute sentiment score
            texts = [q + r for q, r in zip(batch["query"], batch["response"])]
            pipe_outputs = reward_pipe(texts, **sent_kwargs)
            rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
            ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
            ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
            ref_rewards = [
                torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
            ]
            batch["ref_rewards"] = ref_rewards

            # Run PPO step
            stats = self.step(query_tensors, response_tensors, rewards)
            self.log_stats(
                stats,
                batch,
                rewards,
                columns_to_log=["query", "response", "ref_response", "ref_rewards"],
            )
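Since `HFPPOTrainerBuilder.build` is still a stub, nothing in this PR actually constructs `TRLPPOTrainer` yet. As a hedged sketch only, it could be driven directly the way trl's own sentiment example drives `PPOTrainer`; the model names, dataset, and reward pipeline below are placeholders borrowed from that example, not anything configured by axolotl.

```python
# Hedged sketch only: PPO is not wired into axolotl's config/builder path yet.
# Model names, dataset, and reward model are placeholders from trl's
# sentiment-tuning example, not from this PR.
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead, PPOConfig

from axolotl.core.trainers.trl import TRLPPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# The training loop above expects "input_ids" (query tensors) and "query" (decoded text).
dataset = load_dataset("imdb", split="train[:1%]")
dataset = dataset.map(
    lambda row: {
        "input_ids": tokenizer.encode(row["text"])[:8],
        "query": tokenizer.decode(tokenizer.encode(row["text"])[:8]),
    }
)
dataset.set_format(type="torch")


def collator(data):
    # keep ragged query tensors as plain lists, as trl's examples do
    return {key: [d[key] for d in data] for key in data[0]}


ppo_config = PPOConfig(batch_size=16, mini_batch_size=4)
trainer = TRLPPOTrainer(
    ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator
)

# Any text-classification pipeline works as long as output[1]["score"] is the reward.
reward_pipe = pipeline("text-classification", model="lvwerra/distilbert-imdb")
trainer.train(reward_pipe)
```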
src/axolotl/train.py
CHANGED
@@ -61,6 +61,12 @@ def train(
         msg += " and peft_config..."
     LOG.debug(msg)
     model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
+    model_ref = None
+    if cfg.rl:
+        # load the model again for model_ref/baseline
+        model_ref, _ = load_model(
+            cfg, tokenizer, inference=cli_args.inference, reference_model=True
+        )
 
     safe_serialization = cfg.save_safetensors is True
 
@@ -83,7 +89,7 @@
         freeze_parameters_except(model, cfg.unfrozen_parameters)
 
     trainer = setup_trainer(
-        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+        cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
     )
 
     if hasattr(model, "config"):
src/axolotl/utils/models.py
CHANGED
@@ -200,6 +200,7 @@ def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
     inference: bool = False,
+    reference_model: bool = False,
 ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
     """
     Load a model for a given configuration and tokenizer.
@@ -290,6 +291,15 @@
     model_kwargs["device_map"] = cfg.device_map
     model_kwargs["max_memory"] = cfg.max_memory
     model_kwargs["torch_dtype"] = cfg.torch_dtype
+    # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
+    # if cfg.rl:
+    #     if torch.cuda.device_count() > 1:
+    #         if reference_model:
+    #             model_kwargs["device_map"] = "cuda:" + str(
+    #                 torch.cuda.current_device() + 1
+    #             )
+    #         else:
+    #             model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
 
     if is_deepspeed_zero3_enabled():
         del model_kwargs["device_map"]
@@ -560,9 +570,11 @@
             if hasattr(module, "weight"):
                 module.to(cfg.torch_dtype)
 
-    model, lora_config = load_adapter(model, cfg, cfg.adapter)
+    lora_config = None
+    if not reference_model or cfg.lora_model_dir:
+        model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
-    if cfg.ddp and not load_in_8bit:
+    if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
         model.to(f"cuda:{cfg.local_rank}")
 
     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
src/axolotl/utils/trainer.py
CHANGED
@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
 from datasets import set_caching_enabled
 from torch.utils.data import DataLoader, RandomSampler
 
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler
 
@@ -280,7 +280,12 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
+    if cfg.rl:
+        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+        trainer_builder.model_ref = model[1]
+    else:
+        trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
+
     trainer_builder.train_dataset = train_dataset
     trainer_builder.eval_dataset = eval_dataset
 
tests/core/test_trainer_builder.py
ADDED
@@ -0,0 +1,59 @@
"""
unit tests for axolotl.core.trainer_builder
"""
import pytest

from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer


@pytest.fixture(name="cfg")
def fixture_cfg():
    return DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "model_type": "AutoModelForCausalLM",
            "tokenizer_type": "LlamaTokenizer",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "learning_rate": 0.00005,
            "save_steps": 100,
            "output_dir": "./model-out",
            "warmup_steps": 10,
            "gradient_checkpointing": False,
            "optimizer": "adamw_torch",
            "sequence_len": 2048,
            "rl": True,
            "adam_beta1": 0.998,
            "adam_beta2": 0.9,
            "adam_epsilon": 0.00001,
            "dataloader_num_workers": 1,
            "dataloader_pin_memory": True,
        }
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
    return load_tokenizer(cfg)


@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
    return load_model(cfg, tokenizer)


class TestHFDPOTrainerBuilder:
    """
    TestCase class for DPO trainer builder
    """

    def test_build_training_arguments(self, cfg, model, tokenizer):
        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
        training_arguments = builder.build_training_arguments(100)
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True