burtenshaw HF staff commited on
Commit
4ad32d0
1 Parent(s): 8170922

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ wandb
2
+ src/__pycache__
3
+ scripts/run_orpo.sh
4
+ src/accelerate/fsdp.yaml
5
+ scripts/run_orpo.sh
6
+ src/__pycache__/args.cpython-311.pyc
7
+ src/__pycache__/utils.cpython-311.pyc
8
+ src/accelerate/fsdp.yaml
README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # **ORPO**
2
+
3
+ ### **`Updates (24.03.25)`**
4
+ - [X] Sample script for ORPOTrainer in 🤗<a class="link" href="https://github.com/huggingface/trl">TRL</a> is added to `trl/test_orpo_trainer_demo.py`
5
+ - [X] New model, 🤗<a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-capybara-7k">kaist-ai/mistral-orpo-capybara-7k</a>, is added to 🤗<a class="link" href="https://huggingface.co/collections/kaist-ai/orpo-65efef87544ba100aef30013">ORPO Collection</a>
6
+ - [X] Now you can try ORPO in 🤗<a class="link" href="https://github.com/huggingface/trl">TRL</a> and <a class="link" href="https://github.com/OpenAccess-AI-Collective/axolotl">Axolotl</a>🔥
7
+ - [X] We are making general guideline for training LLMs with ORPO, stay tuned🔥
8
+ - [X] **Mistral-ORPO-β** achieved a 14.7% in the length-controlled (LC) win rate on <a class="link" href="https://tatsu-lab.github.io/alpaca_eval/">official AlpacaEval Leaderboard</a>🔥
9
+
10
+ &nbsp;
11
+
12
+ This is the official repository for <a class="link" href="https://arxiv.org/abs/2403.07691">**ORPO: Monolithic Preference Optimization without Reference Model**</a>. The detailed results in the paper can be found in:
13
+ - [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=kaist-ai%2Fmistral-orpo-beta)
14
+ - [AlpacaEval](#alpacaeval)
15
+ - [MT-Bench](#mt-bench)
16
+ - [IFEval](#ifeval)
17
+
18
+
19
+ ### **`Model Checkpoints`**
20
+
21
+ Our models trained with ORPO can be found in:
22
+
23
+ - [X] **Mistral-ORPO-Capybara-7k**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-capybara-7k">kaist-ai/mistral-orpo-capybara-7k</a>
24
+ - [X] **Mistral-ORPO-⍺**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-alpha">kaist-ai/mistral-orpo-alpha</a>
25
+ - [X] **Mistral-ORPO-β**: 🤗 <a class="link" href="https://huggingface.co/kaist-ai/mistral-orpo-beta">kaist-ai/mistral-orpo-beta</a>
26
+
27
+ And the corresponding logs for the average log probabilities of chosen/rejected responses during training are reported in:
28
+
29
+ - [X] **Mistral-ORPO-Capybara-7k**: TBU
30
+ - [X] **Mistral-ORPO-⍺**: <a class="link" href="https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE1NzE0?accessToken=rms6o4mg5vo3feu1bvbpk632m4cspe19l0u1p4he3othx5bgean82chn9neiile6">Wandb Report for Mistral-ORPO-⍺</a>
31
+ - [X] **Mistral-ORPO-β**: <a class="link" href="https://wandb.ai/jiwooya1000/PREF/reports/Mistral-ORPO-7B-Training-Log--Vmlldzo3MTE3MzMy?accessToken=dij4qbp6dcrofsanzbgobjsne9el8a2zkly2u5z82rxisd4wiwv1rhp0s2dub11e">Wandb Report for Mistral-ORPO-β</a>
32
+
33
+ &nbsp;
34
+
35
+ ### **`AlpacaEval`**
36
+
37
+ <figure>
38
+ <img class="png" src="/assets/img/alpaca_blog.png" alt="Description of the image">
39
+ <figcaption><b>Figure 1.</b> AlpacaEval 2.0 score for the models trained with different alignment methods.</figcaption>
40
+ </figure>
41
+
42
+ &nbsp;
43
+
44
+ ### **`MT-Bench`**
45
+
46
+ <figure>
47
+ <img class="png" src="/assets/img/mtbench_hf.png" alt="Description of the image">
48
+ <figcaption><b>Figure 2.</b> MT-Bench result by category.</figcaption>
49
+ </figure>
50
+
51
+ &nbsp;
52
+
53
+ ### **`IFEval`**
54
+
55
+ IFEval scores are measured with <a class="link" href="https://github.com/EleutherAI/lm-evaluation-harness">EleutherAI/lm-evaluation-harness</a> by applying the chat template. The scores for Llama-2-Chat (70B), Zephyr-β (7B), and Mixtral-8X7B-Instruct-v0.1 are originally reported in <a class="link" href="https://twitter.com/wiskojo/status/1739767758462877823">this tweet</a>.
56
+
57
+ | **Model Type** | **Prompt-Strict** | **Prompt-Loose** | **Inst-Strict** | **Inst-Loose** |
58
+ |--------------------|:-----------------:|:----------------:|:---------------:|----------------|
59
+ | **Llama-2-Chat (70B)** | 0.4436 | 0.5342 | 0.5468 | 0.6319 |
60
+ | **Zephyr-β (7B)** | 0.4233 | 0.4547 | 0.5492 | 0.5767 |
61
+ | **Mixtral-8X7B-Instruct-v0.1** | 0.5213 | **0.5712** | 0.6343 | **0.6823** |
62
+ | **Mistral-ORPO-⍺ (7B)** | 0.5009 | 0.5083 | 0.5995 | 0.6163 |
63
+ | **Mistral-ORPO-β (7B)** | **0.5287** | 0.5564 | **0.6355** | 0.6619 |
assets/img/alpaca_blog.png ADDED
assets/img/mtbench_hf.png ADDED
main.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import time
4
+ import wandb
5
+ import torch
6
+ import argparse
7
+ from datasets import load_dataset
8
+ from typing import List, Dict, Union
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ AutoModelForCausalLM,
12
+ TrainingArguments,
13
+ DataCollatorForLanguageModeling
14
+ )
15
+
16
+ from src.args import default_args
17
+ from src.orpo_trainer import ORPOTrainer
18
+ from src.utils import preprocess_logits_for_metrics, dataset_split_selector
19
+
20
+ class ORPO(object):
21
+ def __init__(self, args) -> None:
22
+ self.start = time.gmtime()
23
+ self.args = args
24
+
25
+ # Load Tokenizer
26
+ print(">>> 1. Loading Tokenizer")
27
+ self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, cache_dir=self.args.cache_dir)
28
+ if self.tokenizer.chat_template is None:
29
+ self.tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
30
+ print(" 1-1. Chat Template Applied (<|user|> <|assistant|>)")
31
+ else:
32
+ pass
33
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
34
+
35
+ # Load Model
36
+ print(">>> 2. Loading Model")
37
+ if self.args.flash_attention_2:
38
+ self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
39
+ cache_dir=self.args.cache_dir,
40
+ torch_dtype=torch.bfloat16,
41
+ attn_implementation="flash_attention_2")
42
+ else:
43
+ self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
44
+ cache_dir=self.args.cache_dir,
45
+ torch_dtype=torch.bfloat16)
46
+
47
+ # Load Dataset
48
+ print(">>> 3. Loading Dataset")
49
+ self.data = load_dataset(self.args.data_name, cache_dir=self.args.cache_dir)
50
+
51
+ # Preprocess Dataset
52
+ print(">>> 4. Filtering and Preprocessing Dataset")
53
+ data_split = dataset_split_selector(self.data)
54
+
55
+ if len(data_split) == 1:
56
+ self.is_test = False
57
+ train_split = data_split[0]
58
+ print(f" >>> Test Set = {self.is_test}")
59
+ else:
60
+ self.is_test = True
61
+ train_split = data_split[0]
62
+ test_split = data_split[1]
63
+
64
+ test = self.data[test_split].filter(self.filter_dataset)
65
+ self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
66
+
67
+ train = self.data[train_split].filter(self.filter_dataset).select(range(self.args.max_samples))
68
+ print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
69
+ self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
70
+
71
+ # Set WANDB & Logging Configurations
72
+ self.run_name = f"{self.args.model_name.split('/')[-1]}-{self.args.data_name.split('/')[-1]}-lambda{self.args.alpha}-ORPO-{self.start.tm_mday}-{self.start.tm_hour}-{self.start.tm_min}"
73
+ self.save_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}")
74
+ self.log_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}/logs")
75
+
76
+ os.makedirs(self.save_dir, exist_ok=True)
77
+ os.makedirs(self.log_dir, exist_ok=True)
78
+
79
+ def preprocess_dataset(self, examples: Union[List, Dict]):
80
+ if ('instruction' in examples.keys()) or ('question' in examples.keys()):
81
+ prompt_key = 'instruction' if 'instruction' in examples.keys() else 'question'
82
+ prompt = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item}], tokenize=False, add_generation_prompt=True) for item in examples[prompt_key]]
83
+ chosen = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_chosen}], tokenize=False) for item_prompt, item_chosen in zip(examples[prompt_key], examples['chosen'])]
84
+ rejected = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_rejected}], tokenize=False) for item_prompt, item_rejected in zip(examples[prompt_key], examples['rejected'])]
85
+ else:
86
+ prompt = [self.tokenizer.apply_chat_template([item[0]], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]
87
+ chosen = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]
88
+ rejected = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]
89
+
90
+ model_inputs = self.tokenizer(prompt,
91
+ max_length=self.args.response_max_length,
92
+ padding='max_length',
93
+ truncation=True,
94
+ return_tensors='pt')
95
+ pos_labels = self.tokenizer(chosen,
96
+ max_length=self.args.response_max_length,
97
+ padding='max_length',
98
+ truncation=True,
99
+ return_tensors='pt')
100
+ neg_labels = self.tokenizer(rejected,
101
+ max_length=self.args.response_max_length,
102
+ padding='max_length',
103
+ truncation=True,
104
+ return_tensors='pt')
105
+
106
+ model_inputs['positive_input_ids'] = pos_labels['input_ids']
107
+ model_inputs['positive_attention_mask'] = pos_labels['attention_mask']
108
+
109
+ model_inputs['negative_input_ids'] = neg_labels['input_ids']
110
+ model_inputs['negative_attention_mask'] = neg_labels['attention_mask']
111
+
112
+ return model_inputs
113
+
114
+ def filter_dataset(self, examples: Union[List, Dict]):
115
+ if 'instruction' in examples.keys():
116
+ query = examples['instruction']
117
+ prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
118
+ elif 'question' in examples.keys():
119
+ query = examples['question']
120
+ prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
121
+ else:
122
+ prompt_length = self.tokenizer.apply_chat_template([examples['chosen'][0]], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
123
+
124
+ if prompt_length < self.args.prompt_max_length:
125
+ return True
126
+ else:
127
+ return False
128
+
129
+ def prepare_trainer(self):
130
+ wandb.init(name=self.run_name)
131
+ arguments = TrainingArguments(
132
+ output_dir=self.save_dir, # The output directory
133
+ logging_dir=self.log_dir,
134
+ logging_steps=50,
135
+ learning_rate=self.args.lr,
136
+ overwrite_output_dir=True, # overwrite the content of the output directory
137
+ num_train_epochs=self.args.num_train_epochs, # number of training epochs
138
+ per_device_train_batch_size=self.args.per_device_train_batch_size, # batch size for training
139
+ per_device_eval_batch_size=self.args.per_device_eval_batch_size, # batch size for evaluation
140
+ evaluation_strategy=self.args.evaluation_strategy if self.is_test else 'no', # batch size for evaluation
141
+ save_strategy=self.args.evaluation_strategy,
142
+ optim=self.args.optim,
143
+ warmup_steps=self.args.warmup_steps,
144
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
145
+ gradient_checkpointing=True, #if ('llama' in self.args.model_name.lower()) or ('mistral' in self.args.model_name.lower()) else False,
146
+ gradient_checkpointing_kwargs={'use_reentrant':True},
147
+ load_best_model_at_end=self.is_test,
148
+ do_train=True,
149
+ do_eval=self.is_test,
150
+ lr_scheduler_type=self.args.lr_scheduler_type,
151
+ remove_unused_columns=False,
152
+ report_to='wandb',
153
+ run_name=self.run_name,
154
+ bf16=True
155
+ )
156
+
157
+ data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
158
+
159
+ self.trainer = ORPOTrainer(
160
+ model=self.model,
161
+ alpha=self.args.alpha,
162
+ pad=self.tokenizer.pad_token_id,
163
+ args=arguments,
164
+ train_dataset=self.train,
165
+ eval_dataset=self.test if self.is_test else None,
166
+ data_collator=data_collator,
167
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics
168
+ )
169
+
170
+ def run(self):
171
+ print(">>> 5. Preparing ORPOTrainer")
172
+ self.prepare_trainer()
173
+ self.trainer.train()
174
+
175
+ # Saving code for FSDP
176
+ if self.trainer.is_fsdp_enabled:
177
+ self.trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
178
+ self.trainer.save_model()
179
+
180
+
181
+ if __name__ == '__main__':
182
+ parser = argparse.ArgumentParser("ORPO")
183
+ args = default_args(parser)
184
+
185
+ # Set WANDB configurations
186
+ if args.wandb_entity is not None and args.wandb_project_name is not None:
187
+ os.environ["WANDB_ENTITY"] = args.wandb_entity
188
+ os.environ["WANDB_PROJECT"] = args.wandb_project_name
189
+ else:
190
+ pass
191
+ os.environ["TOKENIZERS_PARALLELISM"] = 'false'
192
+
193
+ print("================================================================================================\n")
194
+ print(f">>> Fine-tuning {args.model_name} with ORPO on {args.data_name}\n")
195
+ print("================================================================================================")
196
+ print("\n\n>>> Summary:")
197
+ print(f" - Lambda : {args.alpha}")
198
+ print(f" - Training Epochs : {args.num_train_epochs}")
199
+ print(f" - Prompt Max Length : {args.prompt_max_length}")
200
+ print(f" - Response Max Length : {args.response_max_length}")
201
+
202
+ item = ORPO(args=args)
203
+ item.run()
outputs/alpacaeval/Mistral-ORPO-alpha.json ADDED
The diff for this file is too large to render. See raw diff
 
outputs/alpacaeval/Mistral-ORPO-beta.json ADDED
The diff for this file is too large to render. See raw diff
 
outputs/mtbench/Mistral-ORPO-alpha.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
outputs/mtbench/Mistral-ORPO-beta.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate @ file:///home/conda/feedstock_root/build_artifacts/accelerate_1710334587919/work
2
+ aiohttp @ file:///croot/aiohttp_1707342283163/work
3
+ aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
4
+ appdirs==1.4.4
5
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
6
+ attrs @ file:///croot/attrs_1695717823297/work
7
+ bitsandbytes==0.43.0
8
+ Bottleneck @ file:///croot/bottleneck_1707864210935/work
9
+ Brotli @ file:///work/ci_py311/brotli-split_1676830125088/work
10
+ cachetools==5.3.3
11
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
12
+ cffi @ file:///croot/cffi_1700254295673/work
13
+ charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
14
+ click @ file:///croot/click_1698129812380/work
15
+ comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
16
+ datasets @ file:///home/conda/feedstock_root/build_artifacts/datasets_1709395865330/work
17
+ debugpy @ file:///croot/debugpy_1690905042057/work
18
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
19
+ dill @ file:///croot/dill_1692271232022/work
20
+ docker-pycreds @ file:///Users/ktietz/demo/mc3/conda-bld/docker-pycreds_1630654474270/work
21
+ einops==0.7.0
22
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work
23
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
24
+ filelock @ file:///croot/filelock_1700591183607/work
25
+ flash-attn==2.5.6
26
+ frozenlist @ file:///croot/frozenlist_1698702560391/work
27
+ fsspec==2023.4.0
28
+ gitdb @ file:///tmp/build/80754af9/gitdb_1617117951232/work
29
+ GitPython @ file:///croot/gitpython_1696936983078/work
30
+ gmpy2 @ file:///work/ci_py311/gmpy2_1676839849213/work
31
+ huggingface-hub @ file:///croot/huggingface_hub_1708634519519/work
32
+ idna @ file:///work/ci_py311/idna_1676822698822/work
33
+ importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1709821103657/work
34
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work
35
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1709559745751/work
36
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
37
+ Jinja2==3.1.2
38
+ jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1710255804825/work
39
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257359434/work
40
+ MarkupSafe @ file:///croot/markupsafe_1704205993651/work
41
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
42
+ mkl-fft @ file:///croot/mkl_fft_1695058164594/work
43
+ mkl-random @ file:///croot/mkl_random_1695059800811/work
44
+ mkl-service==2.4.0
45
+ mpmath @ file:///croot/mpmath_1690848262763/work
46
+ multidict @ file:///croot/multidict_1701096859099/work
47
+ multiprocess @ file:///croot/multiprocess_1692294385131/work
48
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
49
+ networkx==3.2.1
50
+ ninja==1.11.1.1
51
+ numexpr @ file:///croot/numexpr_1696515281613/work
52
+ numpy @ file:///croot/numpy_and_numpy_base_1708638617955/work/dist/numpy-1.26.4-cp311-cp311-linux_x86_64.whl#sha256=5f96f274d410a1682519282ae769c877d32fdbf171aa8badec7bf5e1d3a1748a
53
+ nvidia-cublas-cu11==11.11.3.6
54
+ nvidia-cuda-cupti-cu11==11.8.87
55
+ nvidia-cuda-nvrtc-cu11==11.8.89
56
+ nvidia-cuda-runtime-cu11==11.8.89
57
+ nvidia-cudnn-cu11==8.7.0.84
58
+ nvidia-cufft-cu11==10.9.0.58
59
+ nvidia-curand-cu11==10.3.0.86
60
+ nvidia-cusolver-cu11==11.4.1.48
61
+ nvidia-cusparse-cu11==11.7.5.86
62
+ nvidia-ml-py==12.535.133
63
+ nvidia-nccl-cu11==2.19.3
64
+ nvidia-nvtx-cu11==11.8.86
65
+ nvitop==1.3.2
66
+ packaging @ file:///croot/packaging_1693575174725/work
67
+ pandas @ file:///croot/pandas_1709590491089/work/dist/pandas-2.2.1-cp311-cp311-linux_x86_64.whl#sha256=0a2793a31a0135a35735e1431d453a06186a3a7c607d9b441d9bd5f0fe4ded31
68
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
69
+ pathtools @ file:///Users/ktietz/demo/mc3/conda-bld/pathtools_1629713893697/work
70
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
71
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
72
+ pillow==10.2.0
73
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work
74
+ prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work
75
+ protobuf==3.20.3
76
+ psutil @ file:///work/ci_py311_2/psutil_1679337388738/work
77
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
78
+ pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
79
+ pyarrow @ file:///croot/pyarrow_1707330824290/work/python
80
+ pyarrow-hotfix @ file:///home/conda/feedstock_root/build_artifacts/pyarrow-hotfix_1700596371886/work
81
+ pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
82
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work
83
+ PySocks @ file:///work/ci_py311/pysocks_1676822712504/work
84
+ python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
85
+ pytz @ file:///croot/pytz_1695131579487/work
86
+ PyYAML @ file:///croot/pyyaml_1698096049011/work
87
+ pyzmq @ file:///croot/pyzmq_1705605076900/work
88
+ regex @ file:///croot/regex_1696515298636/work
89
+ requests @ file:///croot/requests_1707355572290/work
90
+ safetensors @ file:///croot/safetensors_1708633833937/work
91
+ sentry-sdk @ file:///work/ci_py311/sentry-sdk_1676862120883/work
92
+ setproctitle @ file:///work/ci_py311/setproctitle_1676838789127/work
93
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
94
+ smmap @ file:///tmp/build/80754af9/smmap_1611694433573/work
95
+ stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
96
+ sympy @ file:///croot/sympy_1701397643339/work
97
+ termcolor==2.4.0
98
+ tokenizers @ file:///croot/tokenizers_1708633814160/work
99
+ torch==2.2.1+cu118
100
+ torchaudio==2.2.1+cu118
101
+ torchvision==0.17.1+cu118
102
+ tornado @ file:///croot/tornado_1696936946304/work
103
+ tqdm @ file:///croot/tqdm_1679561862951/work
104
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1710254411456/work
105
+ transformers @ file:///home/conda/feedstock_root/build_artifacts/transformers_1709308155748/work
106
+ triton==2.2.0
107
+ typing_extensions==4.8.0
108
+ tzdata @ file:///croot/python-tzdata_1690578112552/work
109
+ urllib3 @ file:///croot/urllib3_1707770551213/work
110
+ wandb @ file:///home/conda/feedstock_root/build_artifacts/wandb_1707246480133/work
111
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
112
+ xxhash @ file:///work/ci_py311/python-xxhash_1676842384694/work
113
+ yarl @ file:///croot/yarl_1701105127787/work
114
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work
runpod.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pip install datasets accelerate wandb transformers bitsandbytes sentencepiece
2
+ git clone https://github.com/burtenshaw/orpo.git
3
+ cd orpo
4
+ sed -i 's/num_processes: 2/num_processes: 1/' ./src/accelerate/fsdp.yaml
5
+ sed -i 's/--num_proc", default=8/--num_proc", default=1/' ./src/args.py
6
+ wandb login $WANDB_TOKEN
7
+ wandb init -p $WANDB_PROJECT
8
+ accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
9
+ --lr $LEARNING_RATE \
10
+ --warmup_steps 100 \
11
+ --model_name $MODEL_ID \
12
+ --data_name $DATASET \
13
+ --num_train_epochs $EPOCH \
14
+ --max_samples $MAX_SAMPLES \
15
+ --prompt_max_length 128 \
16
+ --response_max_length 2048 \
17
+ --per_device_train_batch_size 4 \
18
+ --per_device_eval_batch_size 4 \
19
+ --gradient_accumulation_steps 1 \
20
+ --num_proc 1
21
+ cd $OUTPUT
22
+ cd */
23
+ huggingface-cli login --token $TOKEN
24
+ huggingface-cli upload $NEW_MODEL . .
scripts/run_mistral_orpo_beta.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Mistral-ORPO series are trained on 4 * A100s
4
+
5
+ accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
6
+ --lr 5e-6 \
7
+ --lr_scheduler_type inverse_sqrt \
8
+ --alpha 0.1 \
9
+ --torch_compile False \
10
+ --warmup_steps 200 \
11
+ --model_name mistralai/Mistral-7B-v0.1 \
12
+ --data_name argilla/ultrafeedback-binarized-preferences-cleaned \
13
+ --num_train_epochs 5 \
14
+ --prompt_max_length 1792 \
15
+ --response_max_length 2048 \
16
+ --per_device_train_batch_size 8 \
17
+ --per_device_eval_batch_size 8 \
18
+ --gradient_accumulation_steps 1 \
19
+ --num_proc 8 \
20
+ --flash_attention_2
scripts/run_mistral_orpo_capybara.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Mistral-ORPO series are trained on 4 * A100s
4
+
5
+ accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
6
+ --lr 5e-6 \
7
+ --torch_compile False \
8
+ --alpha 0.05 \
9
+ --lr_scheduler_type inverse_sqrt \
10
+ --cache_dir /projects/hf_cache/ \
11
+ --warmup_steps 100 \
12
+ --model_name mistralai/Mistral-7B-v0.1 \
13
+ --data_name argilla/distilabel-capybara-dpo-7k-binarized \
14
+ --num_train_epochs 3 \
15
+ --optim adamw_bnb_8bit \
16
+ --gradient_accumulation_steps 1 \
17
+ --prompt_max_length 1792 \
18
+ --response_max_length 2048 \
19
+ --per_device_train_batch_size 8 \
20
+ --per_device_eval_batch_size 8 \
21
+ --num_proc 8 \
22
+ --flash_attention_2
src/accelerate/ds2.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 1
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: false
8
+ zero_stage: 2
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ machine_rank: 0
12
+ main_training_function: main
13
+ mixed_precision: bf16
14
+ num_machines: 1
15
+ num_processes: 2
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
src/args.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def default_args(parser):
2
+ parser.add_argument("--cache_dir", default=None, type=str)
3
+ parser.add_argument("--save_dir", default='./saved', type=str)
4
+ parser.add_argument("--data_name", default='HuggingfaceH4/UltraFeedback', type=str)
5
+ parser.add_argument("--model_name", default="gpt2", type=str)
6
+
7
+ # Training Arguments
8
+ parser.add_argument("--torch_compile", default=False, type=bool)
9
+ parser.add_argument("--flash_attention_2", action='store_true')
10
+ parser.add_argument("--lr_scheduler_type", default="cosine", type=str)
11
+ parser.add_argument("--optim", default="paged_adamw_32bit", type=str)
12
+ parser.add_argument("--overwrite_output_dir", default=True, type=bool)
13
+ parser.add_argument("--lr", default=2e-5, type=float)
14
+ parser.add_argument("--num_proc", default=1, type=int)
15
+ parser.add_argument("--num_train_epochs", default=10, type=int)
16
+ parser.add_argument("--per_device_train_batch_size", default=2, type=int)
17
+ parser.add_argument("--per_device_eval_batch_size", default=2, type=int)
18
+ parser.add_argument("--warmup_steps", default=5000, type=int)
19
+ parser.add_argument("--evaluation_strategy", default='epoch', type=str)
20
+ parser.add_argument("--do_eval", action='store_true')
21
+ parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
22
+ parser.add_argument("--save_strategy", default='epoch', type=str)
23
+ parser.add_argument("--prompt_max_length", default=256, type=int)
24
+ parser.add_argument("--response_max_length", default=1024, type=int)
25
+ parser.add_argument("--alpha", default=1.0, type=float, help="Hyperparameter for weighting L_OR")
26
+
27
+ # Wandb Configurations
28
+ parser.add_argument("--wandb_entity", default=None, type=str)
29
+ parser.add_argument("--wandb_project_name", default=None, type=str)
30
+
31
+
32
+ args = parser.parse_args()
33
+
34
+ return args
src/orpo_trainer.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import wandb
4
+ from transformers import Trainer
5
+
6
+
7
+ class ORPOTrainer(Trainer):
8
+ def __init__(self, alpha, pad, *args, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+ self.pad = pad
11
+ self.alpha = alpha
12
+ self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
13
+ print("Pad Token ID: ", self.pad)
14
+
15
+ def compute_custom_loss(self, logits, labels):
16
+
17
+ logits = logits.contiguous()
18
+
19
+ if labels is not None:
20
+ # move labels to correct device to enable model parallelism
21
+ labels = labels.to(logits.device)
22
+ # Shift so that tokens < n predict n
23
+ shift_logits = logits[..., :-1, :].contiguous()
24
+ shift_labels = labels[..., 1:].contiguous()
25
+
26
+ # Flatten the tokens
27
+ loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(dim=-1)
28
+
29
+ return loss
30
+
31
+ def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
32
+ mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
33
+ per_token_logps = torch.gather(logits[:, :-1, :].log_softmax(-1), dim=2,
34
+ index=(mask * chosen_inputs[:, 1:]).unsqueeze(2)).squeeze(2)
35
+ return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(dtype=torch.float64) / mask.sum(dim=1).to(dtype=torch.float64)
36
+
37
+ def compute_loss(self, model, inputs, return_outputs=False):
38
+ if self.label_smoother is not None and "labels" in inputs:
39
+ labels = inputs.pop("labels")
40
+ else:
41
+ labels = None
42
+
43
+ # Generate the hidden states for 'chosen' and 'reject'
44
+ neg_labels = inputs['negative_input_ids'].clone()
45
+ pos_labels = inputs['positive_input_ids'].clone()
46
+
47
+ neg_labels[neg_labels == self.pad] = -100
48
+ pos_labels[pos_labels == self.pad] = -100
49
+
50
+ outputs_neg = model(**{'input_ids': inputs['negative_input_ids'],
51
+ 'attention_mask': inputs['negative_attention_mask'],
52
+ 'labels': neg_labels,}, output_hidden_states=True)
53
+ outputs_pos = model(**{'input_ids': inputs['positive_input_ids'],
54
+ 'attention_mask': inputs['positive_attention_mask'],
55
+ 'labels': pos_labels,}, output_hidden_states=True)
56
+
57
+ # Calculate NLL loss
58
+ pos_loss = self.compute_custom_loss(logits=outputs_pos.logits, labels=inputs['positive_input_ids'])
59
+
60
+ # Calculate Log Probability
61
+ pos_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
62
+ chosen_inputs=inputs['positive_input_ids'],
63
+ chosen_attention_mask=inputs['positive_attention_mask'],
64
+ logits=outputs_pos.logits)
65
+ neg_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
66
+ chosen_inputs=inputs['negative_input_ids'],
67
+ chosen_attention_mask=inputs['negative_attention_mask'],
68
+ logits=outputs_neg.logits)
69
+
70
+ # Calculate log odds
71
+ log_odds = (pos_prob - neg_prob) - (torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)))
72
+ sig_ratio = torch.nn.functional.sigmoid(log_odds)
73
+ ratio = torch.log(sig_ratio)
74
+
75
+ # Calculate the Final Loss
76
+ loss = torch.mean(pos_loss - self.alpha * ratio).to(dtype=torch.bfloat16)
77
+
78
+ wandb.log({'Positive Geometric Mean': torch.mean(pos_prob).item(),
79
+ 'Negative Geometric Mean': torch.mean(neg_prob).item(),
80
+ 'Log Odds Ratio': torch.mean(ratio).item(),
81
+ 'Log Odds': torch.mean(log_odds).item()})
82
+
83
+ return (loss, outputs_pos) if return_outputs else loss
src/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List
3
+
4
+ def preprocess_logits_for_metrics(logits, labels):
5
+ if isinstance(logits, tuple):
6
+ logits = logits[0]
7
+ return logits.argmax(dim=-1)
8
+
9
+ def dataset_split_selector(data) -> List:
10
+ """
11
+ This is a function for automating the process of selecting data split.
12
+ Will be further updated.
13
+ """
14
+ if len(data.keys()) == 1:
15
+ return ['train']
16
+ else:
17
+ if 'train_prefs' in data.keys():
18
+ return ['train_prefs', 'test_prefs']
19
+ else:
20
+ return ['train', 'test']
trl/test_orpo_trainer_demo.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional
4
+
5
+ import os
6
+ import torch
7
+ from datasets import load_dataset
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, HfArgumentParser, pipeline
10
+
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from trl import ORPOConfig, ORPOTrainer, set_seed
13
+ from trl.core import LengthSampler
14
+
15
+ # This code is built on top of the example code from Huggingface TRL Team
16
+
17
+ tqdm.pandas()
18
+
19
+ @dataclass
20
+ class ScriptArguments:
21
+ model_name: Optional[str] = field(default="microsoft/phi-2", metadata={"help": "the model name"})
22
+ optim: Optional[str] = field(default="adamw_torch", metadata={"help": "the model name"})
23
+ data_name: Optional[str] = field(default="argilla/ultrafeedback-binarized-preferences-cleaned", metadata={"help": "the model name"})
24
+ cache_dir: Optional[str] = field(default="", metadata={"help": "the model name"})
25
+ log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
26
+ output_dir: Optional[str] = field(default='', metadata={"help": "use 'wandb' to log with wandb"})
27
+ learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
28
+ lr_scheduler_type: Optional[str] = field(default='cosine', metadata={"help": "the learning rate scheduler"})
29
+ per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the batch size"})
30
+ num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the batch size"})
31
+ beta: Optional[float] = field(default=0.25, metadata={"help": "weighting hyperparameter for L_OR"})
32
+ gradient_accumulation_steps: Optional[int] = field(
33
+ default=1, metadata={"help": "the number of gradient accumulation steps"}
34
+ )
35
+
36
+
37
+ parser = HfArgumentParser(ScriptArguments)
38
+ script_args = parser.parse_args_into_dataclasses()[0]
39
+
40
+ config = ORPOConfig(
41
+ output_dir=script_args.output_dir,
42
+ max_prompt_length=1024,
43
+ max_length=2048,
44
+ logging_steps=100,
45
+ save_strategy='no',
46
+ max_completion_length=2048,
47
+ per_device_train_batch_size=script_args.per_device_train_batch_size,
48
+ remove_unused_columns=False,
49
+ gradient_accumulation_steps=script_args.gradient_accumulation_steps,
50
+ learning_rate=script_args.learning_rate,
51
+ optim=script_args.optim,
52
+ lr_scheduler_type=script_args.lr_scheduler_type,
53
+ gradient_checkpointing=True,
54
+ gradient_checkpointing_kwargs={'use_reentrant':True},
55
+ beta=script_args.beta,
56
+ report_to='wandb',
57
+ num_train_epochs=script_args.num_train_epochs,
58
+ bf16=True,
59
+ do_eval=False
60
+ )
61
+
62
+ model = AutoModelForCausalLM.from_pretrained(script_args.model_name,
63
+ cache_dir=script_args.cache_dir,
64
+ attn_implementation='flash_attention_2',
65
+ torch_dtype=torch.bfloat16)
66
+ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name,
67
+ cache_dir=script_args.cache_dir)
68
+ tokenizer.pad_token_id = tokenizer.eos_token_id
69
+ tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
70
+
71
+ def build_dataset(tokenizer):
72
+ ds_train = load_dataset(script_args.data_name, split="train",
73
+ cache_dir=script_args.cache_dir)
74
+
75
+ def chat_template_to_text(sample):
76
+ sample["chosen"] = [item_chosen[1]['content'] for item_chosen in sample['chosen']]
77
+ sample["rejected"] = [item_rejected[1]['content'] for item_rejected in sample['rejected']]
78
+ sample['prompt'] = [tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}], tokenize=False, add_generation_prompt=True) for item_prompt in sample['prompt']]
79
+
80
+ return sample
81
+
82
+ ds_train = ds_train.map(chat_template_to_text, batched=True, num_proc=8)
83
+
84
+ return ds_train
85
+
86
+ train = build_dataset(tokenizer=tokenizer)
87
+
88
+ trainer = ORPOTrainer(
89
+ model=model,
90
+ args=config,
91
+ tokenizer=tokenizer,
92
+ train_dataset=train
93
+ )
94
+
95
+ trainer.train()