winglian committed on
Commit a6028d3
1 Parent(s): 8d959a7

black formatting

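The changes below are consistent with running Black, the Python code formatter, over the repository with its default settings (88-character line length). The exact command is not recorded in the commit, so this reproduction is an assumption:

    # Hypothetical reproduction of this formatting pass (assumes Black's defaults)
    python -m pip install black
    python -m black scripts src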
scripts/alpaca_json_to_jsonl.py CHANGED
@@ -6,12 +6,13 @@ import fire
 from typing import Optional

 # add src to the pythonpath so we don't need to pip install this
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
-src_dir = os.path.join(project_root, 'src')
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)

 from axolotl.convert import *

+
 def main(
     input: Path,
     output: Optional[Path] = None,
@@ -25,9 +26,7 @@ def main(
     json_parser = JsonParser()
     jsonl_serializer = JsonlSerializer()

-    converter = JsonToJsonlConverter(
-        file_reader, writer, json_parser, jsonl_serializer
-    )
+    converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)

     converter.convert(input, output)

scripts/finetune.py CHANGED
@@ -14,7 +14,8 @@ from datasets import load_dataset, IterableDataset, Dataset
 from peft import (
     LoraConfig,
     get_peft_model,
-    prepare_model_for_int8_training, get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    get_peft_model_state_dict,
 )
 from torch import nn
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -22,15 +23,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # add src to the pythonpath so we don't need to pip install this
 from transformers.trainer_pt_utils import get_parameter_names

-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
-src_dir = os.path.join(project_root, 'src')
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)

 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
-from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
-    LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
+    LLAMA_DEFAULT_PAD_TOKEN,
+    GPTeacherPromptTokenizingStrategy,
+)
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter

+
 def setup_wandb_env_vars(cfg):
     if len(cfg.wandb_project) > 0:
         os.environ["WANDB_PROJECT"] = cfg.wandb_project
@@ -68,7 +74,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"

     if cfg.load_in_8bit:
@@ -94,11 +100,11 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):


 def train(
-    config: Path = Path('configs/pythia_1_2B_alpaca.yml'),
+    config: Path = Path("configs/pythia_1_2B_alpaca.yml"),
     **kwargs,
 ):
     # load the config from the yaml file
-    with open(config, 'r') as f:
+    with open(config, "r") as f:
         cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
@@ -114,36 +120,52 @@ def train(
     cfg.ddp = cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps // cfg.world_size
+        cfg.gradient_accumulation_steps = (
+            cfg.gradient_accumulation_steps // cfg.world_size
+        )
     setup_wandb_env_vars(cfg)

     # Load the model and tokenizer
-    model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
+    model, tokenizer, lora_config = load_model(
+        cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter
+    )
     datasets = []
     for d in cfg.datasets:
-        ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
+        ds: IterableDataset = load_dataset(
+            "json", data_files=d.path, streaming=True, split=None
+        )
         if d.type == "alpaca":
-            ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
+            ds_strategy = AlpacaPromptTokenizingStrategy(
+                AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+            )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
         elif d.type == "gpteacher":
-            ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
+            ds_strategy = GPTeacherPromptTokenizingStrategy(
+                GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+            )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
         elif d.type == "sharegpt":
-            ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
+            ds_strategy = ShareGPTPromptTokenizingStrategy(
+                ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+            )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
-    constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
-    constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
-        test_size=cfg.val_set_size, shuffle=True, seed=42
+    constant_len_dataset = ConstantLengthDataset(
+        tokenizer, datasets, seq_length=cfg.sequence_len
     )
+    constant_len_dataset = Dataset.from_list(
+        [_ for _ in constant_len_dataset]
+    ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)

     print(constant_len_dataset)
     train_dataset = constant_len_dataset["train"]
     eval_dataset = constant_len_dataset["test"]

-    total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
+    total_num_steps = int(
+        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+    )
     warmup_steps = min(int(0.03 * total_num_steps), 100)
     logging_steps = min(int(0.005 * total_num_steps), 10)
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
@@ -178,7 +200,9 @@ def train(
             "weight_decay": training_args.weight_decay,
         },
         {
-            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
+            "params": [
+                p for n, p in model.named_parameters() if n not in decay_parameters
+            ],
             "weight_decay": 0.0,
         },
     ]
@@ -210,18 +234,16 @@ def train(

     old_state_dict = model.state_dict
     model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(
-            self, old_state_dict()
-        )
+        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
     ).__get__(model, type(model))

     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)

-    signal.signal(signal.SIGINT, lambda signal, frame: (
-        model.save_pretrained(cfg.output_dir),
-        exit(0)
-    ))
+    signal.signal(
+        signal.SIGINT,
+        lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
+    )

     # go ahead and presave the adapter config
     lora_config.save_pretrained(cfg.output_dir)
@@ -229,5 +251,6 @@ def train(

     model.save_pretrained(cfg.output_dir)

+
 if __name__ == "__main__":
     fire.Fire(train)
src/axolotl/convert.py CHANGED
@@ -47,5 +47,3 @@ class JsonToJsonlConverter:
         # data = [r for r in data if r["conversations"]]  # vicuna cleaned has rows with empty conversations
         jsonl_content = self.jsonl_serializer.serialize(data)
         self.file_writer.write(jsonl_content)
-
-
src/axolotl/datasets.py CHANGED
@@ -71,10 +71,18 @@ class ConstantLengthDataset(IterableDataset):
                 else:
                     example_len = 0

-                if not example_len or buffer_len + int(add_concat_token) + example_len > self.seq_length:
+                if (
+                    not example_len
+                    or buffer_len + int(add_concat_token) + example_len
+                    > self.seq_length
+                ):
                     if buffer["input_ids"]:
-                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[: self.seq_length]
-                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[: self.seq_length]
+                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[
+                            : self.seq_length
+                        ]
+                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
+                            : self.seq_length
+                        ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                         yield {
                             "input_ids": input_ids,
@@ -95,7 +103,9 @@ class ConstantLengthDataset(IterableDataset):
                         labels.append(self.concat_token_id)

                     input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
-                    attention_mask_with_concat = torch.tensor(attention_mask, dtype=torch.long)
+                    attention_mask_with_concat = torch.tensor(
+                        attention_mask, dtype=torch.long
+                    )
                     labels_with_concat = torch.tensor(labels, dtype=torch.long)

                     buffer["input_ids"].append(input_ids_with_concat)
src/axolotl/prompt_tokenizers.py CHANGED
@@ -42,7 +42,9 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
         user_prompt_len = len(tokenized_user_prompt["input_ids"])
         # TODO this could be sped up using numpy array slicing
-        tokenized_full_prompt["labels"] = [-100] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
+        tokenized_full_prompt["labels"] = [
+            -100
+        ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]

         return tokenized_full_prompt

src/axolotl/prompters.py CHANGED
@@ -20,13 +20,9 @@ class AlpacaPrompter:
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
         if input:
-            res = self.prompt_input.format(
-                instruction=instruction, input=input
-            )
+            res = self.prompt_input.format(instruction=instruction, input=input)
         else:
-            res = self.prompt_no_input.format(
-                instruction=instruction
-            )
+            res = self.prompt_no_input.format(instruction=instruction)
         if output:
             res = f"{res}{output}"
         return res
@@ -41,6 +37,7 @@ class GPTeacherPrompter(AlpacaPrompter):

 class SeparatorStyle(Enum):
     """Different separator style."""
+
     SINGLE = auto()
     TWO = auto()
     DOLLY = auto()
@@ -50,6 +47,7 @@ class SeparatorStyle(Enum):
 @dataclasses.dataclass
 class Conversation:
     """A class that keeps all conversation history."""
+
     system: str
     roles: List[str]
     messages: List[List[str]]
@@ -85,7 +83,7 @@ class Conversation:

 conv_vicuna_v1_1 = Conversation(
     system="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
     roles=["USER", "ASSISTANT"],
     messages=[],
     offset=0,
@@ -96,11 +94,7 @@ conv_vicuna_v1_1 = Conversation(


 class ShareGPTPrompter:
-    def build_prompt(
-        self,
-        source,
-        tokenizer
-    ):
+    def build_prompt(self, source, tokenizer):
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
@@ -111,7 +105,10 @@ class ShareGPTPrompter:

         try:
             # Apply prompt templates
-            if source[0]["from"] not in roles or roles[source[0]["from"]] != conv.roles[0]:
+            if (
+                source[0]["from"] not in roles
+                or roles[source[0]["from"]] != conv.roles[0]
+            ):
                 # Skip the first one if it is not from human
                 source = source[1:]
         except IndexError as e:
@@ -150,11 +147,19 @@ class ShareGPTPrompter:
             parts[0] += sep
             round_len = len(tokenizer(rou)["input_ids"])
             instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
-            target[cur_len:cur_len+instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
+            target[cur_len : cur_len + instruction_len] = [
+                IGNORE_TOKEN_ID
+            ] * instruction_len

             cur_len += round_len
         target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
-        attention_mask = [1 if x != tokenizer.pad_token_id else 0 for x in tokenized_result["input_ids"]]
-
-        return dict(input_ids=tokenized_result["input_ids"], labels=target,
-                    attention_mask=attention_mask)
+        attention_mask = [
+            1 if x != tokenizer.pad_token_id else 0
+            for x in tokenized_result["input_ids"]
+        ]
+
+        return dict(
+            input_ids=tokenized_result["input_ids"],
+            labels=target,
+            attention_mask=attention_mask,
+        )