winglian committed
Commit f2a2029 · 1 Parent(s): a6028d3

config chooser, update readme instructions, device config, llama flash attention, debug out the labels, fix config key checks, other bugfixes

README.md CHANGED
@@ -2,8 +2,18 @@
 
 #### You know you're going to axolotl questions
 
-### Converting JSON data files to JSONL
+## Getting Started
+
+- Download some datasets.
+
+```shell
+curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
+curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
+curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
+curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
+```
+
+- Convert the JSON data files to JSONL.
 
 ```shell
 python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
@@ -11,3 +21,13 @@ python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json >
 python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
 python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
 ```
+
+- Using JSONL makes it easier to subset the data if you want a smaller training set, e.g. grab 2000 random examples.
+
+```shell
+shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
+```
+
+- Create a new YAML config or update the existing one at [configs/pythia_1_2B_alpaca.yml](configs/pythia_1_2B_alpaca.yml).
+- Install the Python dependencies: `pip3 install -r requirements.txt`
+- Train! `python3 scripts/finetune.py`, making sure to choose the correct YAML config file.
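
That last "Train!" step relies on the config chooser this commit adds to `scripts/finetune.py`: passing a directory prompts interactively for a `*.yml` file, while passing a file path uses it directly. Below is a minimal, hypothetical sketch of driving the same entrypoint from Python rather than the CLI; the import path and the `num_epochs=3` override are illustrative assumptions, not part of the commit.

```python
from pathlib import Path

# Hypothetical programmatic equivalent of the CLI step above; assumes the repo root
# is on sys.path so that scripts/finetune.py is importable as a module.
from scripts.finetune import train

# A file path uses that config directly; a directory such as Path("configs/")
# would trigger the interactive chooser added in this commit.
train(config=Path("configs/cerebras_1_3B_alpaca.yml"), num_epochs=3)
# num_epochs matches a key in the YAML, so the kwarg overrides the config value.
```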
configs/cerebras_1_3B_alpaca.yml ADDED
@@ -0,0 +1,38 @@
+base_model: cerebras/Cerebras-GPT-1.3B
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: true
+datasets:
+  - path: data/alpaca_data_gpt4.jsonl
+    type: alpaca
+  - path: data/vicuna_cleaned.jsonl
+    type: sharegpt
+  - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
+    type: gpteacher
+  - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
+    type: gpteacher
+val_set_size: 0.05
+adapter: lora
+sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - c_attn
+lora_fan_in_fan_out: false
+wandb_project: pythia-1.4b-lora
+wandb_watch:
+wandb_run_name:
+wandb_log_model: checkpoint
+output_dir: ./lora-alpaca
+batch_size: 32
+micro_batch_size: 4
+num_epochs: 5
+learning_rate: 0.0003
+train_on_inputs: false
+group_by_length: false
+bf16: True
+tf32: True
+resume_from_checkpoint:
+local_rank:
+deepspeed:
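
Note that `batch_size` and `micro_batch_size` in this config are not independent: the training script derives the gradient accumulation steps from them. A quick sketch of that arithmetic, using the values from the YAML above:

```python
# Mirrors cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
# in scripts/finetune.py, with the values from configs/cerebras_1_3B_alpaca.yml.
batch_size = 32
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size
print(gradient_accumulation_steps)  # 8 micro-batches accumulated per optimizer step
```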
requirements.txt CHANGED
@@ -4,3 +4,9 @@ attrdict
 fire
 PyYAML==6.0
 black
+bitsandbytes
+datasets
+accelerate
+sentencepiece
+wandb
+flash-attn
scripts/finetune.py CHANGED
@@ -9,7 +9,7 @@ import fire
 import torch
 import transformers
 import yaml
-from attrdict import AttrDict
+from attrdict import AttrDefault
 from datasets import load_dataset, IterableDataset, Dataset
 from peft import (
     LoraConfig,
@@ -50,6 +50,11 @@ def setup_wandb_env_vars(cfg):
 def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
+    if "llama" in base_model:
+        from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+
+        replace_llama_attn_with_flash_attn()
+
     try:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
@@ -99,24 +104,104 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     return model, tokenizer, lora_config


+def choose_device(cfg):
+    def get_device():
+        if torch.cuda.is_available():
+            return "cuda"
+        else:
+            try:
+                if torch.backends.mps.is_available():
+                    return "mps"
+            except:
+                return "cpu"
+
+    cfg.device = get_device()
+    if cfg.device == "cuda":
+        cfg.device_map = {"": cfg.local_rank}
+    else:
+        cfg.device_map = {"": cfg.device}
+
+
+def check_dataset_labels(dataset, tokenizer):
+    from termcolor import colored
+
+    # the dataset is already shuffled, so let's just check the first 5 elements
+    for idx in range(5):
+        # Get the input_ids, labels, and attention_mask from the dataset
+        input_ids = dataset[idx]["input_ids"]
+        labels = dataset[idx]["labels"]
+        attention_mask = dataset[idx]["attention_mask"]
+
+        # You can compare the input_ids and labels element-wise
+        # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
+        colored_tokens = []
+        for i, (input_id, label_id, mask) in enumerate(
+            zip(input_ids, labels, attention_mask)
+        ):
+            decoded_input_token = tokenizer.decode(input_id)
+            # Choose the color based on whether the label has the ignore value or not
+            color = (
+                "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
+            )
+            colored_token = colored(decoded_input_token, color) + colored(
+                f"({label_id}, {mask})", "white"
+            )
+            colored_tokens.append(colored_token)
+
+        print(" ".join(colored_tokens))
+        print("\n\n\n")
+
+
+def choose_config(path: Path):
+    yaml_files = [file for file in path.glob("*.yml")]
+
+    if not yaml_files:
+        raise ValueError("No YAML config files found in the specified directory. Are you using a .yml extension?")
+
+    print("Choose a YAML file:")
+    for idx, file in enumerate(yaml_files):
+        print(f"{idx + 1}. {file}")
+
+    chosen_file = None
+    while chosen_file is None:
+        try:
+            choice = int(input("Enter the number of your choice: "))
+            if 1 <= choice <= len(yaml_files):
+                chosen_file = yaml_files[choice - 1]
+            else:
+                print("Invalid choice. Please choose a number from the list.")
+        except ValueError:
+            print("Invalid input. Please enter a number.")
+
+    return chosen_file
+
+
 def train(
-    config: Path = Path("configs/pythia_1_2B_alpaca.yml"),
+    config: Path = Path("configs/"),
     **kwargs,
 ):
+    if config.is_dir():
+        config = choose_config(config)
+
     # load the config from the yaml file
     with open(config, "r") as f:
-        cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
+        cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
-    for k, v in enumerate(kwargs):
-        if k in cfg:
-            cfg.k = v
+    cfg_keys = dict(cfg).keys()
+    for k in kwargs:
+        if k in cfg_keys:
+            # handle booleans
+            if isinstance(cfg[k], bool):
+                cfg[k] = bool(kwargs[k])
+            else:
+                cfg[k] = kwargs[k]

     # setup some derived config / hyperparams
     cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
-    cfg.device_map = "auto"
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    choose_device(cfg)
     cfg.ddp = cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
@@ -163,6 +248,8 @@ def train(
     train_dataset = constant_len_dataset["train"]
     eval_dataset = constant_len_dataset["test"]

+    # check_dataset_labels(eval_dataset, tokenizer)
+
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
@@ -240,6 +327,7 @@ def train(
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)

+    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     signal.signal(
         signal.SIGINT,
         lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
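
The rewritten override loop above only touches keys that already exist in the loaded YAML and coerces booleans before assignment. A small self-contained sketch of that behaviour, using an inline stand-in config rather than a real file:

```python
import yaml
from attrdict import AttrDefault

# Stand-in for a loaded YAML config; only keys that already exist can be overridden.
cfg = AttrDefault(lambda: None, yaml.load("load_in_8bit: true\nnum_epochs: 5\n", Loader=yaml.Loader))

overrides = {"num_epochs": 3, "load_in_8bit": 0, "not_in_yaml": "ignored"}
cfg_keys = dict(cfg).keys()
for k in overrides:
    if k in cfg_keys:
        # booleans are coerced, everything else is assigned as-is
        cfg[k] = bool(overrides[k]) if isinstance(cfg[k], bool) else overrides[k]

print(cfg.num_epochs, cfg.load_in_8bit)  # -> 3 False
```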
setup.cfg CHANGED
@@ -17,6 +17,12 @@ install_requires =
     fire
     PyYAML == 6.0
     black
+    bitsandbytes
+    datasets
+    accelerate
+    sentencepiece
+    wandb
+    flash-attn

 [options.packages.find]
 where = src
src/axolotl/flash_attn.py ADDED
@@ -0,0 +1,116 @@
+# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
+
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from einops import rearrange
+
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel
+
+    attention_mask: [bsz, q_len]
+    """
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    # [bsz, q_len, nh, hd]
+    # [bsz, nh, q_len, hd]
+
+    kv_seq_len = key_states.shape[-2]
+    assert past_key_value is None, "past_key_value is not supported"
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+    assert not output_attentions, "output_attentions is not supported"
+    assert not use_cache, "use_cache is not supported"
+
+    # Flash attention codes from
+    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+    # transform the data into the format required by flash attention
+    qkv = torch.stack(
+        [query_states, key_states, value_states], dim=2
+    )  # [bsz, nh, 3, q_len, hd]
+    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+    # We have disabled _prepare_decoder_attention_mask in LlamaModel
+    # the attention_mask should be the same as the key_padding_mask
+    key_padding_mask = attention_mask
+
+    if key_padding_mask is None:
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+        max_s = q_len
+        cu_q_lens = torch.arange(
+            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+        )
+        output = flash_attn_unpadded_qkvpacked_func(
+            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+    else:
+        nheads = qkv.shape[-2]
+        x = rearrange(qkv, "b s three h d -> b s (three h d)")
+        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+        x_unpad = rearrange(
+            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+        )
+        output_unpad = flash_attn_unpadded_qkvpacked_func(
+            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(
+            pad_input(
+                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+            ),
+            "b s (h d) -> b s h d",
+            h=nheads,
+        )
+    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+    # [bsz, seq_len]
+    return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+        _prepare_decoder_attention_mask
+    )
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
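
As wired up in `load_model` earlier in this commit, the patch is applied before the model is instantiated so that `LlamaAttention.forward` and `_prepare_decoder_attention_mask` are already replaced when `from_pretrained` builds the module tree. A minimal usage sketch; the checkpoint path is a placeholder, not taken from the commit:

```python
import torch
import transformers

from axolotl.flash_attn import replace_llama_attn_with_flash_attn

# Patch the transformers llama classes first, as load_model() in scripts/finetune.py
# does whenever the base model name contains "llama".
replace_llama_attn_with_flash_attn()

# Placeholder checkpoint path; flash-attn kernels expect a CUDA device and half precision.
model = transformers.LlamaForCausalLM.from_pretrained(
    "path/to/llama-base-model",
    torch_dtype=torch.float16,
    device_map={"": 0},
)
```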
src/axolotl/prompt_tokenizers.py CHANGED
@@ -88,5 +88,5 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
         try:
             return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
-        except (KeyError, AssertionError) as e:
+        except (KeyError, AssertionError, IndexError) as e:
             raise InvalidDataException(str(e))