ORPO (#1419)
* orpo trainer
* rl handling for orpo
* support for remove_unused_columns
* orpo fixes
* fix loader for orpo
* chore: lint
* fix default for remove_unused_columns
* roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora
* better handling of system message for orpo
* revert system prompt changes for chat templates
* no need for else condition
* split dataset parsing into its own component
- docs/rlhf.md +15 -0
- src/axolotl/cli/preprocess.py +1 -1
- src/axolotl/cli/train.py +1 -1
- src/axolotl/core/trainer_builder.py +142 -1
- src/axolotl/prompt_strategies/base.py +20 -0
- src/axolotl/prompt_strategies/dpo/__init__.py +3 -15
- src/axolotl/prompt_strategies/orpo/__init__.py +9 -0
- src/axolotl/prompt_strategies/orpo/chat_template.py +187 -0
- src/axolotl/train.py +1 -1
- src/axolotl/utils/chat_templates.py +1 -1
- src/axolotl/utils/config/__init__.py +5 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +5 -0
- src/axolotl/utils/freeze.py +5 -3
- tests/test_prompt_tokenizers.py +56 -1
docs/rlhf.md
CHANGED
@@ -34,6 +34,21 @@ datasets:
 rl: ipo
 ```
 
+#### ORPO
+
+Paper: https://arxiv.org/abs/2403.07691
+
+```yaml
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned
+    type: orpo.chat_template
+```
+
 #### Using local dataset files
 ```yaml
 datasets:
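For reference, the objective from the linked paper, summarized in my own notation rather than quoted from the diff: with `orpo_alpha` playing the role of the paper's weighting term, the trainer adds a log-odds-ratio penalty to the usual NLL loss on the chosen response,

$$
\mathcal{L}_{\text{ORPO}} = \mathcal{L}_{\text{NLL}}(y_w \mid x) \;-\; \alpha \cdot \log \sigma\!\left(\log \frac{\operatorname{odds}(y_w \mid x)}{\operatorname{odds}(y_l \mid x)}\right),
\qquad
\operatorname{odds}(y \mid x) = \frac{P(y \mid x)}{1 - P(y \mid x)}
$$

where $y_w$ and $y_l$ are the chosen and rejected completions and $P(y \mid x)$ is the length-normalized (geometric-mean) token probability computed by the trainer changes further down in this diff.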
src/axolotl/cli/preprocess.py
CHANGED
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    if parsed_cfg.rl:
+    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
src/axolotl/cli/train.py
CHANGED
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
-    if cfg.rl:
+    if cfg.rl and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
src/axolotl/core/trainer_builder.py
CHANGED
@@ -11,10 +11,11 @@ import math
 import os
 import sys
 from abc import abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Type, Union
 
 import torch
 import transformers
@@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "whether this is a qlora training"},
     )
+    orpo_alpha: Optional[float] = field(
+        default=None,
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -223,6 +227,9 @@ class AxolotlTrainer(Trainer):
         self.eval_data_collator = eval_data_collator
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        if self.args.orpo_alpha:
+            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
     def create_optimizer(self):
         if self.args.loraplus_lr_ratio is None:
@@ -465,8 +472,112 @@
         # outputs = model(**inputs)
         # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         # return (loss, outputs) if return_outputs else loss
+        if self.args.orpo_alpha:
+            return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
+    def orpo_compute_custom_loss(self, logits, labels):
+        logits = logits.contiguous()
+        loss = 0.0
+
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+            # Flatten the tokens
+            loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
+                dim=-1
+            )
+
+        return loss
+
+    def orpo_compute_logps(
+        self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
+    ):
+        # Get the shape of chosen_attention_mask[:, :-1]
+        chosen_shape = chosen_attention_mask[:, :-1].shape
+
+        # Calculate the padding size
+        pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
+
+        # Pad prompt_attention_mask with zeros to match the desired shape
+        prompt_attention_mask_padded = torch.nn.functional.pad(
+            prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
+        )
+
+        # Perform the subtraction operation
+        mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
+
+        per_token_logps = torch.gather(
+            logits[:, :-1, :].log_softmax(-1),
+            dim=2,
+            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
+        ).squeeze(2)
+        return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
+            dtype=torch.float64
+        ) / mask.sum(dim=1).to(dtype=torch.float64)
+
+    def orpo_compute_loss(self, model, inputs, return_outputs=False):
+        outputs_neg = model(
+            **{
+                "input_ids": inputs["rejected_input_ids"],
+                "attention_mask": inputs["rejected_attention_mask"],
+                "labels": inputs["rejected_labels"],
+            },
+            output_hidden_states=True,
+        )
+        outputs_pos = model(
+            **{
+                "input_ids": inputs["input_ids"],
+                "attention_mask": inputs["attention_mask"],
+                "labels": inputs["labels"],
+            },
+            output_hidden_states=True,
+        )
+
+        # Calculate NLL loss
+        pos_loss = self.orpo_compute_custom_loss(
+            logits=outputs_pos.logits, labels=inputs["input_ids"]
+        )
+
+        # Calculate Log Probability
+        pos_prob = self.orpo_compute_logps(
+            prompt_attention_mask=inputs["prompt_attention_mask"],
+            chosen_inputs=inputs["input_ids"],
+            chosen_attention_mask=inputs["attention_mask"],
+            logits=outputs_pos.logits,
+        )
+        neg_prob = self.orpo_compute_logps(
+            prompt_attention_mask=inputs["prompt_attention_mask"],
+            chosen_inputs=inputs["rejected_input_ids"],
+            chosen_attention_mask=inputs["rejected_attention_mask"],
+            logits=outputs_neg.logits,
+        )
+
+        # Calculate log odds
+        log_odds = (pos_prob - neg_prob) - (
+            torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
+        )
+        sig_ratio = torch.nn.functional.sigmoid(log_odds)
+        ratio = torch.log(sig_ratio)
+
+        # Calculate the Final Loss
+        loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
+            dtype=torch.bfloat16
+        )
+
+        metrics = {}
+        metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
+        metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
+        metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
+        metrics["log_odds"] = torch.mean(log_odds).cpu().item()
+        self.store_metrics(metrics, train_eval="train")
+
+        return (loss, outputs_pos) if return_outputs else loss
+
     @wraps(Trainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """
@@ -527,6 +638,28 @@ class AxolotlTrainer(Trainer):
 
         return res
 
+    def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs)
+
+    def store_metrics(
+        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
+    ) -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
 
 class AxolotlMambaTrainer(AxolotlTrainer):
     """
@@ -903,6 +1036,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
             training_arguments_kwargs["dataloader_drop_last"] = True
 
+        if self.cfg.remove_unused_columns is not None:
+            training_arguments_kwargs[
+                "remove_unused_columns"
+            ] = self.cfg.remove_unused_columns
+
         if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
             training_arguments_kwargs["evaluation_strategy"] = "no"
@@ -1070,6 +1208,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
 
+        if self.cfg.rl == "orpo":
+            training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
+
         if self.cfg.neftune_noise_alpha is not None:
             training_arguments_kwargs[
                 "neftune_noise_alpha"
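To make the new loss path easier to follow, here is a small illustrative sketch (not part of the PR) of how `orpo_compute_loss` combines the NLL on the chosen completion with the log-odds-ratio term. The tensor values are invented and stand in for the outputs of `orpo_compute_custom_loss` and `orpo_compute_logps`:

```python
import torch

alpha = 0.1  # corresponds to cfg.orpo_alpha / args.orpo_alpha

# length-normalized log-probabilities of each completion, one value per
# sequence, as produced by orpo_compute_logps (values made up here)
pos_prob = torch.tensor([-0.8, -1.1], dtype=torch.float64)  # chosen
neg_prob = torch.tensor([-1.5, -1.9], dtype=torch.float64)  # rejected

# log odds ratio: log( odds(chosen) / odds(rejected) ), with odds(p) = p / (1 - p)
log_odds = (pos_prob - neg_prob) - (
    torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
ratio = torch.log(torch.sigmoid(log_odds))  # log sigma(log_odds), always <= 0

# per-sequence NLL of the chosen completion, as produced by
# orpo_compute_custom_loss (values made up here)
pos_loss = torch.tensor([1.2, 1.4], dtype=torch.float64)

# final ORPO loss: SFT loss minus alpha times the log-odds-ratio reward
loss = torch.mean(pos_loss - alpha * ratio)
print(loss.item())
```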
src/axolotl/prompt_strategies/base.py
ADDED
@@ -0,0 +1,20 @@
+"""
+module for base dataset transform strategies
+"""
+
+import importlib
+import logging
+
+LOG = logging.getLogger("axolotl")
+
+
+def load(strategy, cfg, module_base=None, **kwargs):
+    try:
+        load_fn = strategy.split(".")[-1]
+        strategy = ".".join(strategy.split(".")[:-1])
+        mod = importlib.import_module(f".{strategy}", module_base)
+        func = getattr(mod, load_fn)
+        return func(cfg, **kwargs)
+    except Exception:  # pylint: disable=broad-exception-caught
+        LOG.warning(f"unable to load strategy {strategy}")
+        return None
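A rough usage sketch of the shared loader's resolution rule; the dotted string `"my_module.build"` below is invented for illustration and does not exist in the repo:

```python
# Hypothetical illustration (not from the PR) of how base.load resolves a
# dotted strategy string relative to a package.
from axolotl.prompt_strategies.base import load as load_strategy

# Resolution rule:
#   strategy "my_module.build" with module_base "axolotl.prompt_strategies.dpo"
#   -> importlib.import_module(".my_module", "axolotl.prompt_strategies.dpo")
#   -> getattr(module, "build")(cfg, **kwargs)
# Any exception is caught, a warning is logged, and None is returned.
strategy = load_strategy("my_module.build", cfg={}, module_base="axolotl.prompt_strategies.dpo")
assert strategy is None  # "my_module" does not exist, so the loader falls back to None
```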
src/axolotl/prompt_strategies/dpo/__init__.py
CHANGED
@@ -1,20 +1,8 @@
 """
 module for DPO style dataset transform strategies
 """
+from functools import partial
 
-import importlib
-import logging
+from ..base import load as load_base
 
-LOG = logging.getLogger("axolotl")
-
-
-def load(strategy, cfg, **kwargs):
-    try:
-        load_fn = strategy.split(".")[-1]
-        strategy = ".".join(strategy.split(".")[:-1])
-        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
-        func = getattr(mod, load_fn)
-        return func(cfg, **kwargs)
-    except Exception:  # pylint: disable=broad-exception-caught
-        LOG.warning(f"unable to load strategy {strategy}")
-        return None
+load = partial(load_base, module="axolotl.prompt_strategies.dpo")
src/axolotl/prompt_strategies/orpo/__init__.py
ADDED
@@ -0,0 +1,9 @@
+"""
+module for ORPO style dataset transform strategies
+"""
+
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module="axolotl.prompt_strategies.orpo")
src/axolotl/prompt_strategies/orpo/chat_template.py
ADDED
@@ -0,0 +1,187 @@
+"""chatml prompt tokenization strategy for ORPO"""
+from typing import Any, Dict, Generator, List, Optional, Tuple
+
+from pydantic import BaseModel
+
+from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
+from axolotl.prompters import Prompter
+from axolotl.utils.chat_templates import chat_templates
+
+
+class Message(BaseModel):
+    """message/turn"""
+
+    role: str
+    content: str
+    label: Optional[bool] = None
+
+
+class MessageList(BaseModel):
+    """conversation"""
+
+    messages: List[Message]
+
+
+def load(
+    tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    chatml transforms for datasets with system, input, chosen, rejected
+    """
+
+    chat_template = chat_templates("chatml")
+    if ds_cfg and "chat_template" in ds_cfg:
+        chat_template = ds_cfg["chat_template"]
+        try:
+            chat_template = chat_templates(chat_template)
+        except ValueError:
+            pass
+
+    return ORPOTokenizingStrategy(
+        ORPOPrompter(chat_template, tokenizer),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+        dataset_parser=ORPODatasetParsingStrategy(),
+    )
+
+
+class ORPODatasetParsingStrategy:
+    """Strategy to parse chosen rejected dataset into messagelist"""
+
+    def get_chosen_conversation_thread(self, prompt) -> MessageList:
+        """Dataset structure mappings"""
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        messages.append(Message(role="user", content=prompt["prompt"], label=False))
+        messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][1]["content"], label=True
+            )
+        )
+        return MessageList(messages=messages)
+
+    def get_rejected_conversation_thread(self, prompt) -> MessageList:
+        """Dataset structure mappings"""
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        messages.append(Message(role="user", content=prompt["prompt"], label=False))
+        messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][1]["content"], label=True
+            )
+        )
+        return MessageList(messages=messages)
+
+
+class ORPOTokenizingStrategy(PromptTokenizingStrategy):
+    """
+    rejected_input_ids
+    input_ids
+    rejected_attention_mask
+    attention_mask
+    rejected_labels
+    labels
+    """
+
+    def __init__(
+        self,
+        *args,
+        dataset_parser=None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.dataset_parser = dataset_parser
+
+    def tokenize_prompt(self, prompt):
+        # pass the rejected prompt/row to the Prompter to get the formatted prompt
+        prompt_len = 0
+        rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
+            prompt
+        )
+        input_ids = []
+        labels = []
+        for _, (part, label) in enumerate(
+            self.prompter.build_prompt(rejected_message_list)
+        ):
+            if not part:
+                continue
+            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+            prev_idx = len(input_ids)
+            input_ids += _input_ids[prev_idx:]
+            if label:
+                labels += input_ids[prev_idx:]
+            else:
+                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+                prompt_len = len(input_ids)
+        # remap the input_ids, attention_mask and labels
+        rejected_input_ids = input_ids
+        rejected_labels = labels
+        # pass the chosen prompt/row to the Prompter to get the formatted prompt
+        chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
+        input_ids = []
+        labels = []
+        for _, (part, label) in enumerate(
+            self.prompter.build_prompt(chosen_message_list)
+        ):
+            if not part:
+                continue
+            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+            prev_idx = len(input_ids)
+            input_ids += _input_ids[prev_idx:]
+            if label:
+                labels += input_ids[prev_idx:]
+            else:
+                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+
+        return {
+            "rejected_input_ids": rejected_input_ids,
+            "rejected_labels": rejected_labels,
+            "rejected_attention_mask": [1] * len(rejected_labels),
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": [1] * len(labels),
+            "prompt_attention_mask": [1] * prompt_len
+            + [0] * (len(labels) - prompt_len),
+        }
+
+
+class ORPOPrompter(Prompter):
+    """Single Turn prompter for ORPO"""
+
+    def __init__(self, chat_template, tokenizer):
+        self.chat_template = chat_template
+        self.tokenizer = tokenizer
+
+    def build_prompt(
+        self,
+        message_list: MessageList,
+    ) -> Generator[Tuple[str, bool], None, None]:
+        conversation = []
+        for message in message_list.messages:
+            conversation.append(message.model_dump())
+            if message.role == "system":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=False,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), False
+            if message.role == "user":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=True,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), False
+            if message.role == "assistant":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=False,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), True
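For context, a hedged sketch of the row shape `ORPODatasetParsingStrategy` expects, modeled on the argilla/ultrafeedback-binarized-preferences-cleaned columns used in the docs and tests; the literal strings below are invented:

```python
from axolotl.prompt_strategies.orpo.chat_template import ORPODatasetParsingStrategy

# Invented example row; only the fields read by the parser are shown.
row = {
    "system": "You are a helpful assistant.",  # optional, prepended when present
    "prompt": "What is the capital of France?",
    "chosen": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},  # chosen[1]["content"] is used
    ],
    "rejected": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "I am not sure."},  # rejected[1]["content"] is used
    ],
}

parser = ORPODatasetParsingStrategy()
chosen = parser.get_chosen_conversation_thread(row)
# -> MessageList of system (label=False), user (label=False), assistant (label=True);
#    only the label=True turn contributes to `labels`, the rest become IGNORE_INDEX.
```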
src/axolotl/train.py
CHANGED
@@ -85,7 +85,7 @@ def train(
     model.generation_config.do_sample = True
 
     model_ref = None
-    if cfg.rl:
+    if cfg.rl and cfg.rl != "orpo":
         if cfg.adapter and not cfg.rl_adapter_ref_model:
             # use built-in trl autounwrap
             LOG.debug("Passing model_ref: None to RL trainer")
src/axolotl/utils/chat_templates.py
CHANGED
@@ -21,7 +21,7 @@ def chat_templates(user_choice: str):
     templates = {
         "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if …
+        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
     }
 
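A hedged sketch of what the registered chatml template renders to when passed through transformers' `apply_chat_template`; the model name is only an example of a tokenizer to load:

```python
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # example tokenizer
messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hi there"},
]
text = tokenizer.apply_chat_template(
    messages,
    chat_template=chat_templates("chatml"),
    tokenize=False,
    add_generation_prompt=False,
)
# -> "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\nhi there<|im_end|>\n"
print(text)
```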
src/axolotl/utils/config/__init__.py
CHANGED
@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
                     f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
                 )
                 cfg.datasets[idx].conversation = "chatml"
+            if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
+                LOG.info(
+                    f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
+                )
+                cfg.datasets[idx].chat_template = "chatml"
 
 
 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -124,6 +124,7 @@ class RLType(str, Enum):
|
|
124 |
dpo = "dpo" # pylint: disable=invalid-name
|
125 |
ipo = "ipo" # pylint: disable=invalid-name
|
126 |
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
|
|
127 |
|
128 |
|
129 |
class ChatTemplate(str, Enum):
|
@@ -431,6 +432,8 @@ class AxolotlInputConfig(
|
|
431 |
dataloader_prefetch_factor: Optional[int] = None
|
432 |
dataloader_drop_last: Optional[bool] = None
|
433 |
|
|
|
|
|
434 |
push_dataset_to_hub: Optional[str] = None
|
435 |
hf_use_auth_token: Optional[bool] = None
|
436 |
|
@@ -515,6 +518,8 @@ class AxolotlInputConfig(
|
|
515 |
|
516 |
neftune_noise_alpha: Optional[float] = None
|
517 |
|
|
|
|
|
518 |
max_memory: Optional[
|
519 |
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
520 |
] = None
|
|
|
124 |
dpo = "dpo" # pylint: disable=invalid-name
|
125 |
ipo = "ipo" # pylint: disable=invalid-name
|
126 |
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
127 |
+
orpo = "orpo" # pylint: disable=invalid-name
|
128 |
|
129 |
|
130 |
class ChatTemplate(str, Enum):
|
|
|
432 |
dataloader_prefetch_factor: Optional[int] = None
|
433 |
dataloader_drop_last: Optional[bool] = None
|
434 |
|
435 |
+
remove_unused_columns: Optional[bool] = None
|
436 |
+
|
437 |
push_dataset_to_hub: Optional[str] = None
|
438 |
hf_use_auth_token: Optional[bool] = None
|
439 |
|
|
|
518 |
|
519 |
neftune_noise_alpha: Optional[float] = None
|
520 |
|
521 |
+
orpo_alpha: Optional[float] = None
|
522 |
+
|
523 |
max_memory: Optional[
|
524 |
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
525 |
] = None
|
src/axolotl/utils/freeze.py
CHANGED
@@ -3,7 +3,7 @@ module to freeze/unfreeze parameters by name
 """
 import logging
 import re
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union
 
 from axolotl.utils.distributed import is_main_process
 
@@ -99,7 +99,7 @@ def _invert_ranges(
 
 
 def _merge_ranges(
-    given_ranges: List[Tuple[int, int]], layer_size: int
+    given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
 ) -> List[Tuple[int, int]]:
     """
     Merges overlapping ranges and sorts the given ranges.
@@ -194,7 +194,9 @@ class LayerNamePattern:
         """
         return self.name_regex.match(name) is not None
 
-    def _parse_pattern(…
+    def _parse_pattern(
+        self, pattern: str
+    ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
         """
         Extracts the range pattern from the given pattern.
 
tests/test_prompt_tokenizers.py
CHANGED
@@ -8,7 +8,8 @@ from pathlib import Path
 from typing import Optional
 
 import pytest
-from transformers import AutoTokenizer, LlamaTokenizer
+from datasets import load_dataset
+from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
 
 from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
 from axolotl.prompt_strategies.alpaca_w_system import (
@@ -19,12 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import (
     Llama2ChatPrompter,
     LLama2ChatTokenizingStrategy,
 )
+from axolotl.prompt_strategies.orpo.chat_template import load
 from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     ShareGPTPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
+from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger("axolotl")
 
@@ -446,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why
     )
 
 
+class OrpoTokenizationTest(unittest.TestCase):
+    """test case for the ORPO tokenization"""
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+        tokenizer.add_special_tokens(
+            {
+                "eos_token": AddedToken(
+                    "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+                )
+            }
+        )
+        tokenizer.add_tokens(
+            [
+                AddedToken(
+                    "<|im_start|>", rstrip=False, lstrip=False, normalized=False
+                ),
+            ]
+        )
+        self.tokenizer = tokenizer
+        self.dataset = load_dataset(
+            "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
+        ).select([0])
+
+    def test_orpo_integration(self):
+        strat = load(
+            self.tokenizer,
+            DictDefault({"train_on_inputs": False}),
+            DictDefault({"chat_template": "chatml"}),
+        )
+        res = strat.tokenize_prompt(self.dataset[0])
+        assert "rejected_input_ids" in res
+        assert "rejected_labels" in res
+        assert "input_ids" in res
+        assert "labels" in res
+        assert "prompt_attention_mask" in res
+
+        assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
+        assert len(res["input_ids"]) == len(res["labels"])
+        assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
+
+        assert res["rejected_labels"][0] == -100
+        assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
+
+        assert res["labels"][0] == -100
+        assert res["input_ids"][-1] == res["labels"][-1]
+
+        assert res["prompt_attention_mask"][0] == 1
+        assert res["prompt_attention_mask"][-1] == 0
+
+
 if __name__ == "__main__":
     unittest.main()
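Assuming a standard pytest setup, the new case can be run on its own with `pytest tests/test_prompt_tokenizers.py -k Orpo`; it downloads the Mistral tokenizer and a single dataset row, so it needs network access.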