milandean committed
Commit
2c91676
1 Parent(s): 90aa07a

Upload example_sft_qlora.py with huggingface_hub

Files changed (1)
  1. example_sft_qlora.py +146 -0
example_sft_qlora.py ADDED
@@ -0,0 +1,146 @@
from dataclasses import dataclass, field
from typing import Optional

import torch

from transformers import AutoTokenizer, HfArgumentParser, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer


@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, their capacity and features, and the size of the model you want to train.
    """
    per_device_train_batch_size: Optional[int] = field(default=4)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.001)
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=8)
    max_seq_length: Optional[int] = field(default=2048)
    model_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model that you want to train from the Hugging Face Hub, e.g. gpt2, gpt2-xl, etc."
        }
    )
    dataset_name: Optional[str] = field(
        default="stingning/ultrachat",
        metadata={"help": "The dataset to use for supervised fine-tuning."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=True,
        metadata={"help": "Use packing when creating the dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    use_flash_attention_2: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention 2."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant is slightly better than cosine and is easier to analyze."},
    )
    max_steps: int = field(default=1000, metadata={"help": "How many optimizer update steps to take."})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for."})
    save_steps: int = field(default=10, metadata={"help": "Save a checkpoint every X update steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X update steps."})
    output_dir: str = field(
        default="./results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
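# Example invocation (flag values are illustrative, not prescriptive):
#   python example_sft_qlora.py --model_name google/gemma-7b --max_steps 500 --bf16 True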


def formatting_func(example):
    text = f"### USER: {example['data'][0]}\n### ASSISTANT: {example['data'][1]}"
    return text
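# Note: in stingning/ultrachat each example's "data" field is a list of strings
# with alternating user/assistant turns, so this keeps only the first exchange
# of every conversation.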

# Base model: Gemma 7B from the Hugging Face Hub (overridable via --model_name)
model_id = script_args.model_name if script_args.model_name is not None else "google/gemma-7b"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4"
)
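# Standard QLoRA setup: the base weights are stored in 4-bit NF4 and
# dequantized to float16 for the forward and backward passes.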

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float32,
    attn_implementation="sdpa" if not script_args.use_flash_attention_2 else "flash_attention_2"
)
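# Note: with a quantization_config, torch_dtype applies to the modules that are
# left unquantized (e.g. layer norms); the 4-bit linear layers follow the
# compute dtype set above.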

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
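# Reusing EOS as the pad token is a common convention for decoder-only models
# whose tokenizers do not ship a dedicated padding token.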

lora_config = LoraConfig(
    r=script_args.lora_r,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout
)
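# These target_modules cover every linear projection in a Gemma decoder block:
# the four attention projections plus the three MLP projections.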

train_dataset = load_dataset(script_args.dataset_name, split="train[:5%]")

# TODO: make this configurable
YOUR_HF_USERNAME = "xxx"  # placeholder; replace with your Hugging Face username
output_dir = f"{YOUR_HF_USERNAME}/gemma-qlora-ultrachat"
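# Note: as written, this username-prefixed path is only a local directory name;
# actually pushing checkpoints to the Hub would also require push_to_hub=True
# (and hub_model_id) in TrainingArguments.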

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    max_grad_norm=script_args.max_grad_norm,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=script_args.gradient_checkpointing,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
)
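# With the defaults this trains with an effective batch of
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 4 = 16
# sequences per optimizer step on each device.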

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    peft_config=lora_config,
    packing=script_args.packing,
    tokenizer=tokenizer,
    max_seq_length=script_args.max_seq_length,
    formatting_func=formatting_func,
)

trainer.train()
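
# A possible follow-up (not in the original script): persist the trained LoRA
# adapter so it can be reloaded with peft later.
trainer.save_model(training_arguments.output_dir)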