winglian committed
Commit 2ea70eb
Parent: e8c8ea6

ORPO (#1419)

* orpo trainer

* rl handling for orpo

* support for remove_unused_columns

* orpo fixes

* fix loader for orpo

* chore: lint

* fix default for remove_unused_columns

* roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora

* better handling of system message for orpo

* revert system prompt changes for chat templates

* no need for else condition

* split dataset parsing into its own component

docs/rlhf.md CHANGED
@@ -34,6 +34,21 @@ datasets:
 rl: ipo
 ```
 
+#### ORPO
+
+Paper: https://arxiv.org/abs/2403.07691
+
+```yaml
+rl: orpo
+orpo_alpha: 0.1
+remove_unused_columns: false
+
+chat_template: chatml
+datasets:
+  - path: argilla/ultrafeedback-binarized-preferences-cleaned
+    type: orpo.chat_template
+```
+
 #### Using local dataset files
 ```yaml
 datasets:
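
For reference, the objective this config enables (restated from the linked paper rather than taken from the diff; α corresponds to the `orpo_alpha` setting above): ORPO folds the preference signal into the SFT loss itself via an odds-ratio penalty, so no separate reference model is needed.

```latex
\mathcal{L}_{\mathrm{ORPO}}
  = \mathbb{E}_{(x,\,y_w,\,y_l)}\!\left[
      \mathcal{L}_{\mathrm{SFT}}(x, y_w)
      - \alpha \,\log \sigma\!\left(
          \log \frac{\mathrm{odds}(y_w \mid x)}{\mathrm{odds}(y_l \mid x)}
        \right)
    \right],
\qquad
\mathrm{odds}(y \mid x) = \frac{P(y \mid x)}{1 - P(y \mid x)}
```

Here `y_w` is the chosen response and `y_l` the rejected one.
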
src/axolotl/cli/preprocess.py CHANGED
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    if parsed_cfg.rl:
+    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
        load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
src/axolotl/cli/train.py CHANGED
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
-    if cfg.rl:
+    if cfg.rl and cfg.rl != "orpo":
        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
    else:
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
src/axolotl/core/trainer_builder.py CHANGED
@@ -11,10 +11,11 @@ import math
 import os
 import sys
 from abc import abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Type, Union
 
 import torch
 import transformers
@@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "whether this is a qlora training"},
     )
+    orpo_alpha: Optional[float] = field(
+        default=None,
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -223,6 +227,9 @@ class AxolotlTrainer(Trainer):
         self.eval_data_collator = eval_data_collator
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        if self.args.orpo_alpha:
+            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
     def create_optimizer(self):
         if self.args.loraplus_lr_ratio is None:
@@ -465,8 +472,112 @@ class AxolotlTrainer(Trainer):
         # outputs = model(**inputs)
         # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
         # return (loss, outputs) if return_outputs else loss
+        if self.args.orpo_alpha:
+            return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
+    def orpo_compute_custom_loss(self, logits, labels):
+        logits = logits.contiguous()
+        loss = 0.0
+
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+            # Flatten the tokens
+            loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
+                dim=-1
+            )
+
+        return loss
+
+    def orpo_compute_logps(
+        self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
+    ):
+        # Get the shape of chosen_attention_mask[:, :-1]
+        chosen_shape = chosen_attention_mask[:, :-1].shape
+
+        # Calculate the padding size
+        pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
+
+        # Pad prompt_attention_mask with zeros to match the desired shape
+        prompt_attention_mask_padded = torch.nn.functional.pad(
+            prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
+        )
+
+        # Perform the subtraction operation
+        mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
+
+        per_token_logps = torch.gather(
+            logits[:, :-1, :].log_softmax(-1),
+            dim=2,
+            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
+        ).squeeze(2)
+        return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
+            dtype=torch.float64
+        ) / mask.sum(dim=1).to(dtype=torch.float64)
+
+    def orpo_compute_loss(self, model, inputs, return_outputs=False):
+        outputs_neg = model(
+            **{
+                "input_ids": inputs["rejected_input_ids"],
+                "attention_mask": inputs["rejected_attention_mask"],
+                "labels": inputs["rejected_labels"],
+            },
+            output_hidden_states=True,
+        )
+        outputs_pos = model(
+            **{
+                "input_ids": inputs["input_ids"],
+                "attention_mask": inputs["attention_mask"],
+                "labels": inputs["labels"],
+            },
+            output_hidden_states=True,
+        )
+
+        # Calculate NLL loss
+        pos_loss = self.orpo_compute_custom_loss(
+            logits=outputs_pos.logits, labels=inputs["input_ids"]
+        )
+
+        # Calculate Log Probability
+        pos_prob = self.orpo_compute_logps(
+            prompt_attention_mask=inputs["prompt_attention_mask"],
+            chosen_inputs=inputs["input_ids"],
+            chosen_attention_mask=inputs["attention_mask"],
+            logits=outputs_pos.logits,
+        )
+        neg_prob = self.orpo_compute_logps(
+            prompt_attention_mask=inputs["prompt_attention_mask"],
+            chosen_inputs=inputs["rejected_input_ids"],
+            chosen_attention_mask=inputs["rejected_attention_mask"],
+            logits=outputs_neg.logits,
+        )
+
+        # Calculate log odds
+        log_odds = (pos_prob - neg_prob) - (
+            torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
+        )
+        sig_ratio = torch.nn.functional.sigmoid(log_odds)
+        ratio = torch.log(sig_ratio)
+
+        # Calculate the Final Loss
+        loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
+            dtype=torch.bfloat16
+        )
+
+        metrics = {}
+        metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
+        metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
+        metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
+        metrics["log_odds"] = torch.mean(log_odds).cpu().item()
+        self.store_metrics(metrics, train_eval="train")
+
+        return (loss, outputs_pos) if return_outputs else loss
+
     @wraps(Trainer.push_to_hub)
     def push_to_hub(self, *args, **kwargs) -> str:
         """
@@ -527,6 +638,28 @@ class AxolotlTrainer(Trainer):
 
         return res
 
+    def log(self, logs: Dict[str, float]) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs)
+
+    def store_metrics(
+        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
+    ) -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
 
 class AxolotlMambaTrainer(AxolotlTrainer):
     """
@@ -903,6 +1036,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
             training_arguments_kwargs["dataloader_drop_last"] = True
 
+        if self.cfg.remove_unused_columns is not None:
+            training_arguments_kwargs[
+                "remove_unused_columns"
+            ] = self.cfg.remove_unused_columns
+
         if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
             training_arguments_kwargs["evaluation_strategy"] = "no"
@@ -1070,6 +1208,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
         training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
 
+        if self.cfg.rl == "orpo":
+            training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
+
         if self.cfg.neftune_noise_alpha is not None:
             training_arguments_kwargs[
                 "neftune_noise_alpha"
src/axolotl/prompt_strategies/base.py ADDED
@@ -0,0 +1,20 @@
+"""
+module for base dataset transform strategies
+"""
+
+import importlib
+import logging
+
+LOG = logging.getLogger("axolotl")
+
+
+def load(strategy, cfg, module_base=None, **kwargs):
+    try:
+        load_fn = strategy.split(".")[-1]
+        strategy = ".".join(strategy.split(".")[:-1])
+        mod = importlib.import_module(f".{strategy}", module_base)
+        func = getattr(mod, load_fn)
+        return func(cfg, **kwargs)
+    except Exception:  # pylint: disable=broad-exception-caught
+        LOG.warning(f"unable to load strategy {strategy}")
+        return None
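
To make the dynamic dispatch concrete, here is a trace of one resolution, derived directly from the function above; the example values match the ORPO package added later in this commit:

```python
# load("chat_template.load", cfg, module_base="axolotl.prompt_strategies.orpo")
#   load_fn  = "load"
#   strategy = "chat_template"
#   mod      = importlib.import_module(".chat_template", "axolotl.prompt_strategies.orpo")
#   func     = mod.load
#   returns func(cfg, **kwargs), or None (with a warning) if anything raises
```
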
src/axolotl/prompt_strategies/dpo/__init__.py CHANGED
@@ -1,20 +1,8 @@
 """
 module for DPO style dataset transform strategies
 """
-
-import importlib
-import logging
-
-LOG = logging.getLogger("axolotl")
-
-
-def load(strategy, cfg, **kwargs):
-    try:
-        load_fn = strategy.split(".")[-1]
-        strategy = ".".join(strategy.split(".")[:-1])
-        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
-        func = getattr(mod, load_fn)
-        return func(cfg, **kwargs)
-    except Exception:  # pylint: disable=broad-exception-caught
-        LOG.warning(f"unable to load strategy {strategy}")
-        return None
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module_base="axolotl.prompt_strategies.dpo")
src/axolotl/prompt_strategies/orpo/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""
+module for ORPO style dataset transform strategies
+"""
+
+from functools import partial
+
+from ..base import load as load_base
+
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
src/axolotl/prompt_strategies/orpo/chat_template.py ADDED
@@ -0,0 +1,187 @@
+"""chatml prompt tokenization strategy for ORPO"""
+from typing import Any, Dict, Generator, List, Optional, Tuple
+
+from pydantic import BaseModel
+
+from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
+from axolotl.prompters import Prompter
+from axolotl.utils.chat_templates import chat_templates
+
+
+class Message(BaseModel):
+    """message/turn"""
+
+    role: str
+    content: str
+    label: Optional[bool] = None
+
+
+class MessageList(BaseModel):
+    """conversation"""
+
+    messages: List[Message]
+
+
+def load(
+    tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    chatml transforms for datasets with system, input, chosen, rejected
+    """
+
+    chat_template = chat_templates("chatml")
+    if ds_cfg and "chat_template" in ds_cfg:
+        chat_template = ds_cfg["chat_template"]
+        try:
+            chat_template = chat_templates(chat_template)
+        except ValueError:
+            pass
+
+    return ORPOTokenizingStrategy(
+        ORPOPrompter(chat_template, tokenizer),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+        dataset_parser=ORPODatasetParsingStrategy(),
+    )
+
+
+class ORPODatasetParsingStrategy:
+    """Strategy to parse chosen rejected dataset into messagelist"""
+
+    def get_chosen_conversation_thread(self, prompt) -> MessageList:
+        """Dataset structure mappings"""
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        messages.append(Message(role="user", content=prompt["prompt"], label=False))
+        messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][1]["content"], label=True
+            )
+        )
+        return MessageList(messages=messages)
+
+    def get_rejected_conversation_thread(self, prompt) -> MessageList:
+        """Dataset structure mappings"""
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        messages.append(Message(role="user", content=prompt["prompt"], label=False))
+        messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][1]["content"], label=True
+            )
+        )
+        return MessageList(messages=messages)
+
+
+class ORPOTokenizingStrategy(PromptTokenizingStrategy):
+    """
+    rejected_input_ids
+    input_ids
+    rejected_attention_mask
+    attention_mask
+    rejected_labels
+    labels
+    """
+
+    def __init__(
+        self,
+        *args,
+        dataset_parser=None,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.dataset_parser = dataset_parser
+
+    def tokenize_prompt(self, prompt):
+        # pass the rejected prompt/row to the Prompter to get the formatted prompt
+        prompt_len = 0
+        rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
+            prompt
+        )
+        input_ids = []
+        labels = []
+        for _, (part, label) in enumerate(
+            self.prompter.build_prompt(rejected_message_list)
+        ):
+            if not part:
+                continue
+            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+            prev_idx = len(input_ids)
+            input_ids += _input_ids[prev_idx:]
+            if label:
+                labels += input_ids[prev_idx:]
+            else:
+                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+                prompt_len = len(input_ids)
+        # remap the input_ids, attention_mask and labels
+        rejected_input_ids = input_ids
+        rejected_labels = labels
+        # pass the chosen prompt/row to the Prompter to get the formatted prompt
+        chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
+        input_ids = []
+        labels = []
+        for _, (part, label) in enumerate(
+            self.prompter.build_prompt(chosen_message_list)
+        ):
+            if not part:
+                continue
+            _input_ids = self.tokenizer.encode(part, add_special_tokens=False)
+            prev_idx = len(input_ids)
+            input_ids += _input_ids[prev_idx:]
+            if label:
+                labels += input_ids[prev_idx:]
+            else:
+                labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
+
+        return {
+            "rejected_input_ids": rejected_input_ids,
+            "rejected_labels": rejected_labels,
+            "rejected_attention_mask": [1] * len(rejected_labels),
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": [1] * len(labels),
+            "prompt_attention_mask": [1] * prompt_len
+            + [0] * (len(labels) - prompt_len),
+        }
+
+
+class ORPOPrompter(Prompter):
+    """Single Turn prompter for ORPO"""
+
+    def __init__(self, chat_template, tokenizer):
+        self.chat_template = chat_template
+        self.tokenizer = tokenizer
+
+    def build_prompt(
+        self,
+        message_list: MessageList,
+    ) -> Generator[Tuple[str, bool], None, None]:
+        conversation = []
+        for message in message_list.messages:
+            conversation.append(message.model_dump())
+            if message.role == "system":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=False,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), False
+            if message.role == "user":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=True,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), False
+            if message.role == "assistant":
+                yield self.tokenizer.apply_chat_template(
+                    conversation,
+                    add_generation_prompt=False,
+                    chat_template=self.chat_template,
+                    tokenize=False,
+                ), True
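
The parsing strategy above assumes rows shaped like `argilla/ultrafeedback-binarized-preferences-cleaned`: a `prompt` string, an optional `system` string, and `chosen`/`rejected` message lists whose index 1 holds the assistant turn. An illustrative row (values invented):

```python
row = {
    "system": "You are a helpful assistant.",  # optional; becomes a system turn
    "prompt": "What is the capital of France?",
    "chosen": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},  # prompt["chosen"][1]["content"]
    ],
    "rejected": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Lyon."},  # prompt["rejected"][1]["content"]
    ],
}
```
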
src/axolotl/train.py CHANGED
@@ -85,7 +85,7 @@ def train(
     model.generation_config.do_sample = True
 
     model_ref = None
-    if cfg.rl:
+    if cfg.rl and cfg.rl != "orpo":
         if cfg.adapter and not cfg.rl_adapter_ref_model:
             # use built-in trl autounwrap
             LOG.debug("Passing model_ref: None to RL trainer")
src/axolotl/utils/chat_templates.py CHANGED
@@ -21,7 +21,7 @@ def chat_templates(user_choice: str):
     templates = {
         "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
     }
 
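
The net effect of this template change: nothing is injected that is not in the message list. A lone user turn rendered with `add_generation_prompt=True` now produces only the following (the old template would also have prepended a default "You are a helpful assistant." system turn):

```
<|im_start|>user
Hello!<|im_end|>
<|im_start|>assistant
```
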
src/axolotl/utils/config/__init__.py CHANGED
@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
                     f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
                 )
                 cfg.datasets[idx].conversation = "chatml"
+            if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
+                LOG.info(
+                    f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
+                )
+                cfg.datasets[idx].chat_template = "chatml"
 
 
 def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -124,6 +124,7 @@ class RLType(str, Enum):
     dpo = "dpo"  # pylint: disable=invalid-name
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
+    orpo = "orpo"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):
@@ -431,6 +432,8 @@ class AxolotlInputConfig(
     dataloader_prefetch_factor: Optional[int] = None
     dataloader_drop_last: Optional[bool] = None
 
+    remove_unused_columns: Optional[bool] = None
+
     push_dataset_to_hub: Optional[str] = None
     hf_use_auth_token: Optional[bool] = None
 
@@ -515,6 +518,8 @@ class AxolotlInputConfig(
 
     neftune_noise_alpha: Optional[float] = None
 
+    orpo_alpha: Optional[float] = None
+
     max_memory: Optional[
         Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
     ] = None
src/axolotl/utils/freeze.py CHANGED
@@ -3,7 +3,7 @@ module to freeze/unfreeze parameters by name
 """
 import logging
 import re
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union
 
 from axolotl.utils.distributed import is_main_process
 
@@ -99,7 +99,7 @@ def _invert_ranges(
 
 
 def _merge_ranges(
-    given_ranges: List[Tuple[int, int | None]], layer_size: int
+    given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
 ) -> List[Tuple[int, int]]:
     """
     Merges overlapping ranges and sorts the given ranges.
@@ -194,7 +194,9 @@ class LayerNamePattern:
     """
         return self.name_regex.match(name) is not None
 
-    def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
+    def _parse_pattern(
+        self, pattern: str
+    ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
         """
         Extracts the range pattern from the given pattern.
 
tests/test_prompt_tokenizers.py CHANGED
@@ -8,7 +8,8 @@ from pathlib import Path
 from typing import Optional
 
 import pytest
-from transformers import AutoTokenizer, LlamaTokenizer
+from datasets import load_dataset
+from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
 
 from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
 from axolotl.prompt_strategies.alpaca_w_system import (
@@ -19,12 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import (
     Llama2ChatPrompter,
     LLama2ChatTokenizingStrategy,
 )
+from axolotl.prompt_strategies.orpo.chat_template import load
 from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     ShareGPTPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
+from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger("axolotl")
 
@@ -446,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why
     )
 
 
+class OrpoTokenizationTest(unittest.TestCase):
+    """test case for the ORPO tokenization"""
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+        tokenizer.add_special_tokens(
+            {
+                "eos_token": AddedToken(
+                    "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+                )
+            }
+        )
+        tokenizer.add_tokens(
+            [
+                AddedToken(
+                    "<|im_start|>", rstrip=False, lstrip=False, normalized=False
+                ),
+            ]
+        )
+        self.tokenizer = tokenizer
+        self.dataset = load_dataset(
+            "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
+        ).select([0])
+
+    def test_orpo_integration(self):
+        strat = load(
+            self.tokenizer,
+            DictDefault({"train_on_inputs": False}),
+            DictDefault({"chat_template": "chatml"}),
+        )
+        res = strat.tokenize_prompt(self.dataset[0])
+        assert "rejected_input_ids" in res
+        assert "rejected_labels" in res
+        assert "input_ids" in res
+        assert "labels" in res
+        assert "prompt_attention_mask" in res
+
+        assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
+        assert len(res["input_ids"]) == len(res["labels"])
+        assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
+
+        assert res["rejected_labels"][0] == -100
+        assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
+
+        assert res["labels"][0] == -100
+        assert res["input_ids"][-1] == res["labels"][-1]
+
+        assert res["prompt_attention_mask"][0] == 1
+        assert res["prompt_attention_mask"][-1] == 0
+
+
 if __name__ == "__main__":
     unittest.main()
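
The new test can be run in isolation with standard pytest selection, e.g. `pytest tests/test_prompt_tokenizers.py -k Orpo`; note that it downloads the Mistral tokenizer and the ultrafeedback dataset, so it needs network access.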