winglian commited on
Commit
688c73a
·
unverified ·
2 Parent(s): a27d594 2bc1a5b

Merge pull request #26 from OpenAccess-AI-Collective/mpt-triton

Browse files
scripts/finetune.py CHANGED
@@ -191,7 +191,9 @@ def train(
191
  if cfg.debug:
192
  logging.info("check_dataset_labels...")
193
  check_dataset_labels(
194
- train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
 
 
195
  tokenizer,
196
  )
197
 
@@ -218,17 +220,20 @@ def train(
218
  logging.info("Starting trainer...")
219
  resume_from_checkpoint = cfg.resume_from_checkpoint
220
  if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
221
- possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
 
 
222
  if len(possible_checkpoints) > 0:
223
- sorted_paths = sorted(possible_checkpoints, key=lambda path: int(path.split('-')[-1]))
 
 
224
  resume_from_checkpoint = sorted_paths[-1]
225
- logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
 
 
226
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
227
 
228
- logging.info(
229
- f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
230
- )
231
-
232
 
233
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
234
  trainer.save_pretrained(cfg.output_dir)
 
191
  if cfg.debug:
192
  logging.info("check_dataset_labels...")
193
  check_dataset_labels(
194
+ train_dataset.select(
195
+ [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
196
+ ),
197
  tokenizer,
198
  )
199
 
 
220
  logging.info("Starting trainer...")
221
  resume_from_checkpoint = cfg.resume_from_checkpoint
222
  if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
223
+ possible_checkpoints = [
224
+ str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
225
+ ]
226
  if len(possible_checkpoints) > 0:
227
+ sorted_paths = sorted(
228
+ possible_checkpoints, key=lambda path: int(path.split("-")[-1])
229
+ )
230
  resume_from_checkpoint = sorted_paths[-1]
231
+ logging.info(
232
+ f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
233
+ )
234
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
235
 
236
+ logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
 
 
237
 
238
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
239
  trainer.save_pretrained(cfg.output_dir)
setup.py CHANGED
@@ -10,22 +10,22 @@ with open("./requirements.txt", "r") as requirements_file:
10
  install_requires.append(r)
11
 
12
  setup(
13
- name='axolotl',
14
- version='0.1',
15
  description="You know you're going to axolotl questions",
16
- package_dir={'': 'src'},
17
  packages=find_packages(),
18
  install_requires=install_requires,
19
  extras_require={
20
- 'int4': [
21
  "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
22
  ],
23
- 'int4_triton': [
24
  "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
25
  ],
26
- 'extras': [
27
- 'flash-attn',
28
- 'deepspeed',
29
- ]
30
  },
31
  )
 
10
  install_requires.append(r)
11
 
12
  setup(
13
+ name="axolotl",
14
+ version="0.1",
15
  description="You know you're going to axolotl questions",
16
+ package_dir={"": "src"},
17
  packages=find_packages(),
18
  install_requires=install_requires,
19
  extras_require={
20
+ "int4": [
21
  "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
22
  ],
23
+ "int4_triton": [
24
  "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
25
  ],
26
+ "extras": [
27
+ "flash-attn",
28
+ "deepspeed",
29
+ ],
30
  },
31
  )
src/axolotl/datasets.py CHANGED
@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
31
  except InvalidDataException:
32
  pass
33
 
 
34
  # TODO this isn't the best since it can't interleave datasets
35
  class ConstantLengthDataset(IterableDataset):
36
  """
@@ -40,6 +41,7 @@ class ConstantLengthDataset(IterableDataset):
40
  dataset (dataset.Dataset): Dataset with text files.
41
  seq_length (int): Length of token sequences to return.
42
  """
 
43
  def __init__(
44
  self,
45
  tokenizer,
@@ -93,14 +95,19 @@ class ConstantLengthDataset(IterableDataset):
93
  : self.seq_length
94
  ]
95
  labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
96
- if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
 
 
 
97
  yield {
98
  "input_ids": input_ids,
99
  "labels": labels,
100
  "attention_mask": attention_mask,
101
  }
102
  else:
103
- logging.warning("dropping batch due to tensor size mismatch")
 
 
104
  buffer = {"input_ids": [], "attention_mask": [], "labels": []}
105
  buffer_len = 0
106
 
@@ -116,11 +123,15 @@ class ConstantLengthDataset(IterableDataset):
116
  attention_mask.append(1)
117
  labels.append(self.concat_token_id)
118
 
119
- input_ids_with_concat = torch.tensor(input_ids, dtype=self.tokens_dtype)
 
 
120
  attention_mask_with_concat = torch.tensor(
121
  attention_mask, dtype=self.tokens_dtype
122
  )
123
- labels_with_concat = torch.tensor(labels, dtype=self.tokens_dtype)
 
 
124
 
125
  buffer["input_ids"].append(input_ids_with_concat)
126
  buffer["attention_mask"].append(attention_mask_with_concat)
 
31
  except InvalidDataException:
32
  pass
33
 
34
+
35
  # TODO this isn't the best since it can't interleave datasets
36
  class ConstantLengthDataset(IterableDataset):
37
  """
 
41
  dataset (dataset.Dataset): Dataset with text files.
42
  seq_length (int): Length of token sequences to return.
43
  """
44
+
45
  def __init__(
46
  self,
47
  tokenizer,
 
95
  : self.seq_length
96
  ]
97
  labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
98
+ if (
99
+ labels.size() == input_ids.size()
100
+ and attention_mask.size() == input_ids.size()
101
+ ):
102
  yield {
103
  "input_ids": input_ids,
104
  "labels": labels,
105
  "attention_mask": attention_mask,
106
  }
107
  else:
108
+ logging.warning(
109
+ "dropping batch due to tensor size mismatch"
110
+ )
111
  buffer = {"input_ids": [], "attention_mask": [], "labels": []}
112
  buffer_len = 0
113
 
 
123
  attention_mask.append(1)
124
  labels.append(self.concat_token_id)
125
 
126
+ input_ids_with_concat = torch.tensor(
127
+ input_ids, dtype=self.tokens_dtype
128
+ )
129
  attention_mask_with_concat = torch.tensor(
130
  attention_mask, dtype=self.tokens_dtype
131
  )
132
+ labels_with_concat = torch.tensor(
133
+ labels, dtype=self.tokens_dtype
134
+ )
135
 
136
  buffer["input_ids"].append(input_ids_with_concat)
137
  buffer["attention_mask"].append(attention_mask_with_concat)
src/axolotl/prompt_tokenizers.py CHANGED
@@ -126,10 +126,8 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
126
 
127
 
128
  class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
129
- def parse_instruction_fields(self, prompt) -> (str):
130
- return (
131
- prompt["text"]
132
- )
133
 
134
  def tokenize_prompt(self, prompt):
135
  instruction = self.parse_instruction_fields(prompt)
@@ -139,9 +137,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
139
  return tokenized_full_prompt
140
 
141
  def _build_full_prompt(self, instruction):
142
- return self.prompter.build_prompt(
143
- instruction
144
- )
145
 
146
 
147
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -149,8 +145,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
149
  raise NotImplementedError
150
 
151
  def tokenize_prompt(self, prompt):
152
- instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt)
153
- full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected)
 
 
 
 
 
 
 
 
154
  tokenized_full_prompt = self._tokenize(full_prompt)
155
  if not self.train_on_inputs:
156
  user_prompt = self.prompter.build_prompt(
 
126
 
127
 
128
  class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
129
+ def parse_instruction_fields(self, prompt) -> str:
130
+ return prompt["text"]
 
 
131
 
132
  def tokenize_prompt(self, prompt):
133
  instruction = self.parse_instruction_fields(prompt)
 
137
  return tokenized_full_prompt
138
 
139
  def _build_full_prompt(self, instruction):
140
+ return self.prompter.build_prompt(instruction)
 
 
141
 
142
 
143
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
 
145
  raise NotImplementedError
146
 
147
  def tokenize_prompt(self, prompt):
148
+ (
149
+ instruction,
150
+ input,
151
+ output,
152
+ reflection,
153
+ corrected,
154
+ ) = self.parse_instruction_fields(prompt)
155
+ full_prompt = self._build_full_prompt(
156
+ instruction, input, output, reflection, corrected
157
+ )
158
  tokenized_full_prompt = self._tokenize(full_prompt)
159
  if not self.train_on_inputs:
160
  user_prompt = self.prompter.build_prompt(
src/axolotl/prompters.py CHANGED
@@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter):
36
 
37
 
38
  class CompletionPrompter(AlpacaPrompter):
39
- def build_prompt(
40
- self,
41
- instruction: str
42
- ) -> str:
43
  return instruction
44
 
45
  def get_response(self, output: str) -> str:
@@ -75,7 +72,9 @@ class ReflectAlpacaPrompter:
75
  else:
76
  res = self.prompt_no_input.format(instruction=instruction)
77
  if output and reflection and corrected:
78
- label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
 
 
79
  res = f"{res}{label}"
80
  return res
81
 
@@ -200,9 +199,13 @@ class ShareGPTPrompter:
200
  if len(parts) != 2:
201
  break
202
  parts[0] += sep
203
- round_len = len(tokenizer(rou)["input_ids"]) - 1 # -1 ignores the bos_token generated for this
 
 
204
  # we have to strip the initial part, any dangling whitespace creates an additional ghost token
205
- instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1 # -1 ignores the bos_token generated for this
 
 
206
  target[cur_len : cur_len + instruction_len] = [
207
  IGNORE_TOKEN_ID
208
  ] * instruction_len
@@ -212,7 +215,7 @@ class ShareGPTPrompter:
212
  break
213
 
214
  # Fix: Truncate the target to have the same length as input_ids
215
- target = target[:len(tokenized_result["input_ids"])]
216
  # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
217
 
218
  attention_mask = [
 
36
 
37
 
38
  class CompletionPrompter(AlpacaPrompter):
39
+ def build_prompt(self, instruction: str) -> str:
 
 
 
40
  return instruction
41
 
42
  def get_response(self, output: str) -> str:
 
72
  else:
73
  res = self.prompt_no_input.format(instruction=instruction)
74
  if output and reflection and corrected:
75
+ label = self.agent_label.format(
76
+ output=output, reflection=reflection, corrected=corrected
77
+ )
78
  res = f"{res}{label}"
79
  return res
80
 
 
199
  if len(parts) != 2:
200
  break
201
  parts[0] += sep
202
+ round_len = (
203
+ len(tokenizer(rou)["input_ids"]) - 1
204
+ ) # -1 ignores the bos_token generated for this
205
  # we have to strip the initial part, any dangling whitespace creates an additional ghost token
206
+ instruction_len = (
207
+ len(tokenizer(parts[0].strip())["input_ids"]) - 1
208
+ ) # -1 ignores the bos_token generated for this
209
  target[cur_len : cur_len + instruction_len] = [
210
  IGNORE_TOKEN_ID
211
  ] * instruction_len
 
215
  break
216
 
217
  # Fix: Truncate the target to have the same length as input_ids
218
+ target = target[: len(tokenized_result["input_ids"])]
219
  # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
220
 
221
  attention_mask = [
src/axolotl/utils/callbacks.py CHANGED
@@ -1,8 +1,15 @@
1
  import os
2
 
3
- from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
 
 
 
 
 
 
4
  from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
5
 
 
6
  class SavePeftModelCallback(TrainerCallback):
7
  def on_save(
8
  self,
@@ -11,7 +18,9 @@ class SavePeftModelCallback(TrainerCallback):
11
  control: TrainerControl,
12
  **kwargs,
13
  ):
14
- checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
 
 
15
 
16
  peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
17
  kwargs["model"].save_pretrained(peft_model_path)
 
1
  import os
2
 
3
+ from transformers import (
4
+ Seq2SeqTrainer,
5
+ TrainerCallback,
6
+ TrainingArguments,
7
+ TrainerState,
8
+ TrainerControl,
9
+ )
10
  from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
11
 
12
+
13
  class SavePeftModelCallback(TrainerCallback):
14
  def on_save(
15
  self,
 
18
  control: TrainerControl,
19
  **kwargs,
20
  ):
21
+ checkpoint_folder = os.path.join(
22
+ args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
23
+ )
24
 
25
  peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
26
  kwargs["model"].save_pretrained(peft_model_path)
src/axolotl/utils/data.py CHANGED
@@ -2,7 +2,13 @@ import logging
2
  from hashlib import md5
3
  from pathlib import Path
4
 
5
- from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
 
 
 
 
 
 
6
  from huggingface_hub import hf_hub_download
7
 
8
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
@@ -75,7 +81,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
75
  else:
76
  ds = load_dataset(d.path, streaming=True)
77
  else:
78
- fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
 
 
79
  ds = load_dataset("json", data_files=fp, streaming=True, split=None)
80
  if not ds:
81
  raise Exception("unhandled dataset load")
@@ -140,7 +148,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
140
  samples = samples + [i for i in d]
141
  dataset = Dataset.from_list(samples).shuffle(seed=42)
142
  if cfg.local_rank == 0:
143
- logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
 
 
144
  dataset.save_to_disk(prepared_ds_path)
145
 
146
  if cfg.max_packed_sequence_len is not None:
@@ -153,12 +163,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
153
  dataset = Dataset.from_list([_ for _ in constant_len_dataset])
154
 
155
  if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
156
- logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
157
- dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
 
 
 
 
158
 
159
- dataset = dataset.train_test_split(
160
- test_size=cfg.val_set_size, shuffle=False
161
- )
162
  train_dataset = dataset["train"]
163
  eval_dataset = dataset["test"]
164
 
 
2
  from hashlib import md5
3
  from pathlib import Path
4
 
5
+ from datasets import (
6
+ load_from_disk,
7
+ load_dataset,
8
+ IterableDataset,
9
+ Dataset,
10
+ concatenate_datasets,
11
+ )
12
  from huggingface_hub import hf_hub_download
13
 
14
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 
81
  else:
82
  ds = load_dataset(d.path, streaming=True)
83
  else:
84
+ fp = hf_hub_download(
85
+ repo_id=d.path, repo_type="dataset", filename=d.data_files
86
+ )
87
  ds = load_dataset("json", data_files=fp, streaming=True, split=None)
88
  if not ds:
89
  raise Exception("unhandled dataset load")
 
148
  samples = samples + [i for i in d]
149
  dataset = Dataset.from_list(samples).shuffle(seed=42)
150
  if cfg.local_rank == 0:
151
+ logging.info(
152
+ f"Saving merged prepared dataset to disk... {prepared_ds_path}"
153
+ )
154
  dataset.save_to_disk(prepared_ds_path)
155
 
156
  if cfg.max_packed_sequence_len is not None:
 
163
  dataset = Dataset.from_list([_ for _ in constant_len_dataset])
164
 
165
  if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
166
+ logging.info(
167
+ f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
168
+ )
169
+ dataset = dataset.shard(
170
+ num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
171
+ )
172
 
173
+ dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
 
 
174
  train_dataset = dataset["train"]
175
  eval_dataset = dataset["test"]
176
 
src/axolotl/utils/models.py CHANGED
@@ -9,14 +9,18 @@ from transformers import (
9
  AutoModelForCausalLM,
10
  AutoTokenizer,
11
  PreTrainedModel,
 
12
  )
 
13
  try:
14
  from transformers import (
15
  LlamaForCausalLM,
16
  LlamaTokenizer,
17
  )
18
  except:
19
- logging.warning("This version of transformers does not support Llama. Consider upgrading.")
 
 
20
 
21
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
22
 
@@ -40,7 +44,9 @@ def load_model(
40
  # TODO refactor as a kwarg
41
  load_in_8bit = cfg.load_in_8bit
42
  tokenizer = None
43
- is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
 
 
44
 
45
  if is_llama_derived_model and cfg.flash_attention:
46
  if cfg.device not in ["mps", "cpu"] and inference is False:
@@ -49,11 +55,16 @@ def load_model(
49
  logging.info("patching with flash attention")
50
  replace_llama_attn_with_flash_attn()
51
  elif is_llama_derived_model and cfg.xformers_attention:
52
- from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
 
 
 
53
  logging.info("patching with xformers attention")
54
  hijack_llama_attention()
55
 
56
- torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
 
 
57
  try:
58
  if cfg.load_4bit:
59
  from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -74,8 +85,12 @@ def load_model(
74
  try:
75
  snapshot_download_kwargs = {}
76
  if cfg.base_model_ignore_patterns:
77
- snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
78
- cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
 
 
 
 
79
  files = (
80
  list(cache_model_path.glob("*.pt"))
81
  + list(cache_model_path.glob("*.safetensors"))
@@ -116,8 +131,13 @@ def load_model(
116
  trust_remote_code=True if cfg.trust_remote_code is True else False,
117
  )
118
  else:
 
 
 
 
119
  model = AutoModelForCausalLM.from_pretrained(
120
  base_model,
 
121
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
122
  torch_dtype=torch_dtype,
123
  device_map=cfg.device_map,
 
9
  AutoModelForCausalLM,
10
  AutoTokenizer,
11
  PreTrainedModel,
12
+ AutoConfig,
13
  )
14
+
15
  try:
16
  from transformers import (
17
  LlamaForCausalLM,
18
  LlamaTokenizer,
19
  )
20
  except:
21
+ logging.warning(
22
+ "This version of transformers does not support Llama. Consider upgrading."
23
+ )
24
 
25
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
26
 
 
44
  # TODO refactor as a kwarg
45
  load_in_8bit = cfg.load_in_8bit
46
  tokenizer = None
47
+ is_llama_derived_model = "llama" in base_model or (
48
+ cfg.model_type and "llama" in cfg.model_type.lower()
49
+ )
50
 
51
  if is_llama_derived_model and cfg.flash_attention:
52
  if cfg.device not in ["mps", "cpu"] and inference is False:
 
55
  logging.info("patching with flash attention")
56
  replace_llama_attn_with_flash_attn()
57
  elif is_llama_derived_model and cfg.xformers_attention:
58
+ from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
59
+ hijack_llama_attention,
60
+ )
61
+
62
  logging.info("patching with xformers attention")
63
  hijack_llama_attention()
64
 
65
+ torch_dtype = (
66
+ torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
67
+ )
68
  try:
69
  if cfg.load_4bit:
70
  from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
 
85
  try:
86
  snapshot_download_kwargs = {}
87
  if cfg.base_model_ignore_patterns:
88
+ snapshot_download_kwargs[
89
+ "ignore_patterns"
90
+ ] = cfg.base_model_ignore_patterns
91
+ cache_model_path = Path(
92
+ snapshot_download(base_model, **snapshot_download_kwargs)
93
+ )
94
  files = (
95
  list(cache_model_path.glob("*.pt"))
96
  + list(cache_model_path.glob("*.safetensors"))
 
131
  trust_remote_code=True if cfg.trust_remote_code is True else False,
132
  )
133
  else:
134
+ config = AutoConfig.from_pretrained(
135
+ base_model,
136
+ trust_remote_code=True if cfg.trust_remote_code is True else False,
137
+ )
138
  model = AutoModelForCausalLM.from_pretrained(
139
  base_model,
140
+ config=config,
141
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
142
  torch_dtype=torch_dtype,
143
  device_map=cfg.device_map,
src/axolotl/utils/schedulers.py CHANGED
@@ -26,7 +26,10 @@ class InterpolatingLogScheduler(LRScheduler):
26
  if self.last_epoch <= 0:
27
  lrs = [self.min_lr for base_lr in self.base_lrs]
28
  elif self.last_epoch < self.num_steps:
29
- lrs = [self.min_lr * (self.q ** (self.last_epoch - 1)) for base_lr in self.base_lrs]
 
 
 
30
  else:
31
  lrs = [self.max_lr for base_lr in self.base_lrs]
32
 
 
26
  if self.last_epoch <= 0:
27
  lrs = [self.min_lr for base_lr in self.base_lrs]
28
  elif self.last_epoch < self.num_steps:
29
+ lrs = [
30
+ self.min_lr * (self.q ** (self.last_epoch - 1))
31
+ for base_lr in self.base_lrs
32
+ ]
33
  else:
34
  lrs = [self.max_lr for base_lr in self.base_lrs]
35
 
src/axolotl/utils/tokenization.py CHANGED
@@ -1,6 +1,7 @@
1
  from termcolor import colored
2
  import logging
3
 
 
4
  def check_dataset_labels(dataset, tokenizer):
5
  # the dataset is already shuffled, so let's just check the first 5 elements
6
  for idx in range(5):
@@ -11,7 +12,7 @@ def check_example_labels(example, tokenizer):
11
  # Get the input_ids, labels, and attention_mask from the dataset
12
  input_ids = example["input_ids"]
13
  labels = example["labels"]
14
- attention_mask =example["attention_mask"]
15
 
16
  # You can compare the input_ids and labels element-wise
17
  # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
@@ -21,9 +22,7 @@ def check_example_labels(example, tokenizer):
21
  ):
22
  decoded_input_token = tokenizer.decode(input_id)
23
  # Choose the color based on whether the label has the ignore value or not
24
- color = (
25
- "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
26
- )
27
  colored_token = colored(decoded_input_token, color) + colored(
28
  f"({label_id}, {mask}, {input_id})", "white"
29
  )
 
1
  from termcolor import colored
2
  import logging
3
 
4
+
5
  def check_dataset_labels(dataset, tokenizer):
6
  # the dataset is already shuffled, so let's just check the first 5 elements
7
  for idx in range(5):
 
12
  # Get the input_ids, labels, and attention_mask from the dataset
13
  input_ids = example["input_ids"]
14
  labels = example["labels"]
15
+ attention_mask = example["attention_mask"]
16
 
17
  # You can compare the input_ids and labels element-wise
18
  # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
 
22
  ):
23
  decoded_input_token = tokenizer.decode(input_id)
24
  # Choose the color based on whether the label has the ignore value or not
25
+ color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
 
 
26
  colored_token = colored(decoded_input_token, color) + colored(
27
  f"({label_id}, {mask}, {input_id})", "white"
28
  )
src/axolotl/utils/trainer.py CHANGED
@@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
30
  if cfg.logging_steps is not None
31
  else max(min(int(0.005 * total_num_steps), 10), 1)
32
  )
33
- save_steps = (
34
- cfg.save_steps
35
- if cfg.save_steps is not None
36
- else min(int(0.05 * total_num_steps), 200)
37
- )
38
- eval_steps = (
39
- cfg.eval_steps
40
- if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
41
- else save_steps
42
- )
43
 
44
  training_arguments_kwargs = {}
45
  if cfg.bf16 == "full":
@@ -86,26 +78,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
86
 
87
  training_args = transformers.TrainingArguments(
88
  per_device_train_batch_size=cfg.micro_batch_size,
89
- per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None else cfg.micro_batch_size,
 
 
90
  gradient_accumulation_steps=cfg.gradient_accumulation_steps,
91
  eval_accumulation_steps=cfg.gradient_accumulation_steps,
92
  num_train_epochs=cfg.num_epochs,
93
  learning_rate=cfg.learning_rate,
94
  evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
95
- save_strategy="steps",
96
  eval_steps=eval_steps if cfg.val_set_size > 0 else None,
97
  save_steps=save_steps,
98
  output_dir=cfg.output_dir,
99
  save_total_limit=3,
100
  load_best_model_at_end=True
101
- if cfg.val_set_size > 0 and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
 
 
 
102
  else False,
103
  ddp_find_unused_parameters=False if cfg.ddp else None,
104
  group_by_length=cfg.group_by_length,
105
  report_to="wandb" if cfg.use_wandb else None,
106
  run_name=cfg.wandb_run_id if cfg.use_wandb else None,
107
  optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
108
- lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
 
 
109
  weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
110
  **training_arguments_kwargs,
111
  )
@@ -158,6 +157,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
158
  cfg.learning_rate,
159
  total_steps=total_num_steps,
160
  epochs=cfg.num_epochs,
 
161
  **lr_scheduler_kwargs,
162
  )
163
  elif cfg.lr_scheduler == "log_sweep":
@@ -191,7 +191,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
191
  data_collator_kwargs["pad_to_multiple_of"] = 8
192
 
193
  callbacks = []
194
- if cfg.adapter == 'lora':
195
  callbacks.append(SavePeftModelCallback)
196
 
197
  trainer = transformers.Trainer(
 
30
  if cfg.logging_steps is not None
31
  else max(min(int(0.005 * total_num_steps), 10), 1)
32
  )
33
+ save_steps = cfg.save_steps
34
+ eval_steps = cfg.eval_steps
 
 
 
 
 
 
 
 
35
 
36
  training_arguments_kwargs = {}
37
  if cfg.bf16 == "full":
 
78
 
79
  training_args = transformers.TrainingArguments(
80
  per_device_train_batch_size=cfg.micro_batch_size,
81
+ per_device_eval_batch_size=cfg.eval_batch_size
82
+ if cfg.eval_batch_size is not None
83
+ else cfg.micro_batch_size,
84
  gradient_accumulation_steps=cfg.gradient_accumulation_steps,
85
  eval_accumulation_steps=cfg.gradient_accumulation_steps,
86
  num_train_epochs=cfg.num_epochs,
87
  learning_rate=cfg.learning_rate,
88
  evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
89
+ save_strategy="steps" if save_steps else "epoch",
90
  eval_steps=eval_steps if cfg.val_set_size > 0 else None,
91
  save_steps=save_steps,
92
  output_dir=cfg.output_dir,
93
  save_total_limit=3,
94
  load_best_model_at_end=True
95
+ if cfg.val_set_size > 0
96
+ and save_steps is not None
97
+ and save_steps % eval_steps == 0
98
+ and cfg.load_in_8bit is not True
99
  else False,
100
  ddp_find_unused_parameters=False if cfg.ddp else None,
101
  group_by_length=cfg.group_by_length,
102
  report_to="wandb" if cfg.use_wandb else None,
103
  run_name=cfg.wandb_run_id if cfg.use_wandb else None,
104
  optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
105
+ lr_scheduler_type=cfg.lr_scheduler
106
+ if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
107
+ else "cosine",
108
  weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
109
  **training_arguments_kwargs,
110
  )
 
157
  cfg.learning_rate,
158
  total_steps=total_num_steps,
159
  epochs=cfg.num_epochs,
160
+ div_factor=10,
161
  **lr_scheduler_kwargs,
162
  )
163
  elif cfg.lr_scheduler == "log_sweep":
 
191
  data_collator_kwargs["pad_to_multiple_of"] = 8
192
 
193
  callbacks = []
194
+ if cfg.adapter == "lora":
195
  callbacks.append(SavePeftModelCallback)
196
 
197
  trainer = transformers.Trainer(
src/axolotl/utils/wandb.py CHANGED
@@ -2,7 +2,9 @@ import os
2
 
3
 
4
  def setup_wandb_env_vars(cfg):
5
- if cfg.wandb_project and len(cfg.wandb_project) > 0:
 
 
6
  os.environ["WANDB_PROJECT"] = cfg.wandb_project
7
  cfg.use_wandb = True
8
  if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
 
2
 
3
 
4
  def setup_wandb_env_vars(cfg):
5
+ if cfg.wandb_mode and cfg.wandb_mode == "offline":
6
+ os.environ["WANDB_MODE"] = cfg.wandb_mode
7
+ elif cfg.wandb_project and len(cfg.wandb_project) > 0:
8
  os.environ["WANDB_PROJECT"] = cfg.wandb_project
9
  cfg.use_wandb = True
10
  if cfg.wandb_watch and len(cfg.wandb_watch) > 0: