Merge pull request #26 from OpenAccess-AI-Collective/mpt-triton
Browse files
- scripts/finetune.py +13 -8
- setup.py +9 -9
- src/axolotl/datasets.py +15 -4
- src/axolotl/prompt_tokenizers.py +13 -9
- src/axolotl/prompters.py +11 -8
- src/axolotl/utils/callbacks.py +11 -2
- src/axolotl/utils/data.py +20 -8
- src/axolotl/utils/models.py +26 -6
- src/axolotl/utils/schedulers.py +4 -1
- src/axolotl/utils/tokenization.py +3 -4
- src/axolotl/utils/trainer.py +15 -15
- src/axolotl/utils/wandb.py +3 -1
scripts/finetune.py
CHANGED
@@ -191,7 +191,9 @@ def train(
|
|
191 |
if cfg.debug:
|
192 |
logging.info("check_dataset_labels...")
|
193 |
check_dataset_labels(
|
194 |
-
train_dataset.select(
|
|
|
|
|
195 |
tokenizer,
|
196 |
)
|
197 |
|
@@ -218,17 +220,20 @@ def train(
|
|
218 |
logging.info("Starting trainer...")
|
219 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
220 |
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
221 |
-
possible_checkpoints = [
|
|
|
|
|
222 |
if len(possible_checkpoints) > 0:
|
223 |
-
sorted_paths = sorted(
|
|
|
|
|
224 |
resume_from_checkpoint = sorted_paths[-1]
|
225 |
-
logging.info(
|
|
|
|
|
226 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
227 |
|
228 |
-
logging.info(
|
229 |
-
f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
|
230 |
-
)
|
231 |
-
|
232 |
|
233 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
234 |
trainer.save_pretrained(cfg.output_dir)
|
|
|
191 |
if cfg.debug:
|
192 |
logging.info("check_dataset_labels...")
|
193 |
check_dataset_labels(
|
194 |
+
train_dataset.select(
|
195 |
+
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
|
196 |
+
),
|
197 |
tokenizer,
|
198 |
)
|
199 |
|
|
|
220 |
logging.info("Starting trainer...")
|
221 |
resume_from_checkpoint = cfg.resume_from_checkpoint
|
222 |
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
223 |
+
possible_checkpoints = [
|
224 |
+
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
225 |
+
]
|
226 |
if len(possible_checkpoints) > 0:
|
227 |
+
sorted_paths = sorted(
|
228 |
+
possible_checkpoints, key=lambda path: int(path.split("-")[-1])
|
229 |
+
)
|
230 |
resume_from_checkpoint = sorted_paths[-1]
|
231 |
+
logging.info(
|
232 |
+
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
233 |
+
)
|
234 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
235 |
|
236 |
+
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
|
|
|
|
|
|
237 |
|
238 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
239 |
trainer.save_pretrained(cfg.output_dir)
|
setup.py
CHANGED
@@ -10,22 +10,22 @@ with open("./requirements.txt", "r") as requirements_file:
|
|
10 |
install_requires.append(r)
|
11 |
|
12 |
setup(
|
13 |
-
name=
|
14 |
-
version=
|
15 |
description="You know you're going to axolotl questions",
|
16 |
-
package_dir={
|
17 |
packages=find_packages(),
|
18 |
install_requires=install_requires,
|
19 |
extras_require={
|
20 |
-
|
21 |
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
22 |
],
|
23 |
-
|
24 |
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
25 |
],
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
]
|
30 |
},
|
31 |
)
|
|
|
10 |
install_requires.append(r)
|
11 |
|
12 |
setup(
|
13 |
+
name="axolotl",
|
14 |
+
version="0.1",
|
15 |
description="You know you're going to axolotl questions",
|
16 |
+
package_dir={"": "src"},
|
17 |
packages=find_packages(),
|
18 |
install_requires=install_requires,
|
19 |
extras_require={
|
20 |
+
"int4": [
|
21 |
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
22 |
],
|
23 |
+
"int4_triton": [
|
24 |
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
|
25 |
],
|
26 |
+
"extras": [
|
27 |
+
"flash-attn",
|
28 |
+
"deepspeed",
|
29 |
+
],
|
30 |
},
|
31 |
)
|
src/axolotl/datasets.py
CHANGED
@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
|
|
31 |
except InvalidDataException:
|
32 |
pass
|
33 |
|
|
|
34 |
# TODO this isn't the best since it can't interleave datasets
|
35 |
class ConstantLengthDataset(IterableDataset):
|
36 |
"""
|
@@ -40,6 +41,7 @@ class ConstantLengthDataset(IterableDataset):
|
|
40 |
dataset (dataset.Dataset): Dataset with text files.
|
41 |
seq_length (int): Length of token sequences to return.
|
42 |
"""
|
|
|
43 |
def __init__(
|
44 |
self,
|
45 |
tokenizer,
|
@@ -93,14 +95,19 @@ class ConstantLengthDataset(IterableDataset):
|
|
93 |
: self.seq_length
|
94 |
]
|
95 |
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
96 |
-
if
|
|
|
|
|
|
|
97 |
yield {
|
98 |
"input_ids": input_ids,
|
99 |
"labels": labels,
|
100 |
"attention_mask": attention_mask,
|
101 |
}
|
102 |
else:
|
103 |
-
logging.warning(
|
|
|
|
|
104 |
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
105 |
buffer_len = 0
|
106 |
|
@@ -116,11 +123,15 @@ class ConstantLengthDataset(IterableDataset):
|
|
116 |
attention_mask.append(1)
|
117 |
labels.append(self.concat_token_id)
|
118 |
|
119 |
-
input_ids_with_concat = torch.tensor(
|
|
|
|
|
120 |
attention_mask_with_concat = torch.tensor(
|
121 |
attention_mask, dtype=self.tokens_dtype
|
122 |
)
|
123 |
-
labels_with_concat = torch.tensor(
|
|
|
|
|
124 |
|
125 |
buffer["input_ids"].append(input_ids_with_concat)
|
126 |
buffer["attention_mask"].append(attention_mask_with_concat)
|
|
|
31 |
except InvalidDataException:
|
32 |
pass
|
33 |
|
34 |
+
|
35 |
# TODO this isn't the best since it can't interleave datasets
|
36 |
class ConstantLengthDataset(IterableDataset):
|
37 |
"""
|
|
|
41 |
dataset (dataset.Dataset): Dataset with text files.
|
42 |
seq_length (int): Length of token sequences to return.
|
43 |
"""
|
44 |
+
|
45 |
def __init__(
|
46 |
self,
|
47 |
tokenizer,
|
|
|
95 |
: self.seq_length
|
96 |
]
|
97 |
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
98 |
+
if (
|
99 |
+
labels.size() == input_ids.size()
|
100 |
+
and attention_mask.size() == input_ids.size()
|
101 |
+
):
|
102 |
yield {
|
103 |
"input_ids": input_ids,
|
104 |
"labels": labels,
|
105 |
"attention_mask": attention_mask,
|
106 |
}
|
107 |
else:
|
108 |
+
logging.warning(
|
109 |
+
"dropping batch due to tensor size mismatch"
|
110 |
+
)
|
111 |
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
112 |
buffer_len = 0
|
113 |
|
|
|
123 |
attention_mask.append(1)
|
124 |
labels.append(self.concat_token_id)
|
125 |
|
126 |
+
input_ids_with_concat = torch.tensor(
|
127 |
+
input_ids, dtype=self.tokens_dtype
|
128 |
+
)
|
129 |
attention_mask_with_concat = torch.tensor(
|
130 |
attention_mask, dtype=self.tokens_dtype
|
131 |
)
|
132 |
+
labels_with_concat = torch.tensor(
|
133 |
+
labels, dtype=self.tokens_dtype
|
134 |
+
)
|
135 |
|
136 |
buffer["input_ids"].append(input_ids_with_concat)
|
137 |
buffer["attention_mask"].append(attention_mask_with_concat)
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -126,10 +126,8 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|
126 |
|
127 |
|
128 |
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
129 |
-
def parse_instruction_fields(self, prompt) ->
|
130 |
-
return
|
131 |
-
prompt["text"]
|
132 |
-
)
|
133 |
|
134 |
def tokenize_prompt(self, prompt):
|
135 |
instruction = self.parse_instruction_fields(prompt)
|
@@ -139,9 +137,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|
139 |
return tokenized_full_prompt
|
140 |
|
141 |
def _build_full_prompt(self, instruction):
|
142 |
-
return self.prompter.build_prompt(
|
143 |
-
instruction
|
144 |
-
)
|
145 |
|
146 |
|
147 |
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
@@ -149,8 +145,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
149 |
raise NotImplementedError
|
150 |
|
151 |
def tokenize_prompt(self, prompt):
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
tokenized_full_prompt = self._tokenize(full_prompt)
|
155 |
if not self.train_on_inputs:
|
156 |
user_prompt = self.prompter.build_prompt(
|
|
|
126 |
|
127 |
|
128 |
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
129 |
+
def parse_instruction_fields(self, prompt) -> str:
|
130 |
+
return prompt["text"]
|
|
|
|
|
131 |
|
132 |
def tokenize_prompt(self, prompt):
|
133 |
instruction = self.parse_instruction_fields(prompt)
|
|
|
137 |
return tokenized_full_prompt
|
138 |
|
139 |
def _build_full_prompt(self, instruction):
|
140 |
+
return self.prompter.build_prompt(instruction)
|
|
|
|
|
141 |
|
142 |
|
143 |
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
|
145 |
raise NotImplementedError
|
146 |
|
147 |
def tokenize_prompt(self, prompt):
|
148 |
+
(
|
149 |
+
instruction,
|
150 |
+
input,
|
151 |
+
output,
|
152 |
+
reflection,
|
153 |
+
corrected,
|
154 |
+
) = self.parse_instruction_fields(prompt)
|
155 |
+
full_prompt = self._build_full_prompt(
|
156 |
+
instruction, input, output, reflection, corrected
|
157 |
+
)
|
158 |
tokenized_full_prompt = self._tokenize(full_prompt)
|
159 |
if not self.train_on_inputs:
|
160 |
user_prompt = self.prompter.build_prompt(
|
src/axolotl/prompters.py
CHANGED
@@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter):
|
|
36 |
|
37 |
|
38 |
class CompletionPrompter(AlpacaPrompter):
|
39 |
-
def build_prompt(
|
40 |
-
self,
|
41 |
-
instruction: str
|
42 |
-
) -> str:
|
43 |
return instruction
|
44 |
|
45 |
def get_response(self, output: str) -> str:
|
@@ -75,7 +72,9 @@ class ReflectAlpacaPrompter:
|
|
75 |
else:
|
76 |
res = self.prompt_no_input.format(instruction=instruction)
|
77 |
if output and reflection and corrected:
|
78 |
-
label = self.agent_label.format(
|
|
|
|
|
79 |
res = f"{res}{label}"
|
80 |
return res
|
81 |
|
@@ -200,9 +199,13 @@ class ShareGPTPrompter:
|
|
200 |
if len(parts) != 2:
|
201 |
break
|
202 |
parts[0] += sep
|
203 |
-
round_len =
|
|
|
|
|
204 |
# we have to strip the initial part, any dangling whitespace creates an additional ghost token
|
205 |
-
instruction_len =
|
|
|
|
|
206 |
target[cur_len : cur_len + instruction_len] = [
|
207 |
IGNORE_TOKEN_ID
|
208 |
] * instruction_len
|
@@ -212,7 +215,7 @@ class ShareGPTPrompter:
|
|
212 |
break
|
213 |
|
214 |
# Fix: Truncate the target to have the same length as input_ids
|
215 |
-
target = target[:len(tokenized_result["input_ids"])]
|
216 |
# target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
|
217 |
|
218 |
attention_mask = [
|
|
|
36 |
|
37 |
|
38 |
class CompletionPrompter(AlpacaPrompter):
|
39 |
+
def build_prompt(self, instruction: str) -> str:
|
|
|
|
|
|
|
40 |
return instruction
|
41 |
|
42 |
def get_response(self, output: str) -> str:
|
|
|
72 |
else:
|
73 |
res = self.prompt_no_input.format(instruction=instruction)
|
74 |
if output and reflection and corrected:
|
75 |
+
label = self.agent_label.format(
|
76 |
+
output=output, reflection=reflection, corrected=corrected
|
77 |
+
)
|
78 |
res = f"{res}{label}"
|
79 |
return res
|
80 |
|
|
|
199 |
if len(parts) != 2:
|
200 |
break
|
201 |
parts[0] += sep
|
202 |
+
round_len = (
|
203 |
+
len(tokenizer(rou)["input_ids"]) - 1
|
204 |
+
) # -1 ignores the bos_token generated for this
|
205 |
# we have to strip the initial part, any dangling whitespace creates an additional ghost token
|
206 |
+
instruction_len = (
|
207 |
+
len(tokenizer(parts[0].strip())["input_ids"]) - 1
|
208 |
+
) # -1 ignores the bos_token generated for this
|
209 |
target[cur_len : cur_len + instruction_len] = [
|
210 |
IGNORE_TOKEN_ID
|
211 |
] * instruction_len
|
|
|
215 |
break
|
216 |
|
217 |
# Fix: Truncate the target to have the same length as input_ids
|
218 |
+
target = target[: len(tokenized_result["input_ids"])]
|
219 |
# target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
|
220 |
|
221 |
attention_mask = [
|
src/axolotl/utils/callbacks.py
CHANGED
@@ -1,8 +1,15 @@
|
|
1 |
import os
|
2 |
|
3 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
5 |
|
|
|
6 |
class SavePeftModelCallback(TrainerCallback):
|
7 |
def on_save(
|
8 |
self,
|
@@ -11,7 +18,9 @@ class SavePeftModelCallback(TrainerCallback):
|
|
11 |
control: TrainerControl,
|
12 |
**kwargs,
|
13 |
):
|
14 |
-
checkpoint_folder = os.path.join(
|
|
|
|
|
15 |
|
16 |
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
17 |
kwargs["model"].save_pretrained(peft_model_path)
|
|
|
1 |
import os
|
2 |
|
3 |
+
from transformers import (
|
4 |
+
Seq2SeqTrainer,
|
5 |
+
TrainerCallback,
|
6 |
+
TrainingArguments,
|
7 |
+
TrainerState,
|
8 |
+
TrainerControl,
|
9 |
+
)
|
10 |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
11 |
|
12 |
+
|
13 |
class SavePeftModelCallback(TrainerCallback):
|
14 |
def on_save(
|
15 |
self,
|
|
|
18 |
control: TrainerControl,
|
19 |
**kwargs,
|
20 |
):
|
21 |
+
checkpoint_folder = os.path.join(
|
22 |
+
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
|
23 |
+
)
|
24 |
|
25 |
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
26 |
kwargs["model"].save_pretrained(peft_model_path)
|
src/axolotl/utils/data.py
CHANGED
@@ -2,7 +2,13 @@ import logging
|
|
2 |
from hashlib import md5
|
3 |
from pathlib import Path
|
4 |
|
5 |
-
from datasets import
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
|
8 |
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
@@ -75,7 +81,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
75 |
else:
|
76 |
ds = load_dataset(d.path, streaming=True)
|
77 |
else:
|
78 |
-
fp = hf_hub_download(
|
|
|
|
|
79 |
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
|
80 |
if not ds:
|
81 |
raise Exception("unhandled dataset load")
|
@@ -140,7 +148,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
140 |
samples = samples + [i for i in d]
|
141 |
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
142 |
if cfg.local_rank == 0:
|
143 |
-
logging.info(
|
|
|
|
|
144 |
dataset.save_to_disk(prepared_ds_path)
|
145 |
|
146 |
if cfg.max_packed_sequence_len is not None:
|
@@ -153,12 +163,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
|
153 |
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
154 |
|
155 |
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
156 |
-
logging.info(
|
157 |
-
|
|
|
|
|
|
|
|
|
158 |
|
159 |
-
dataset = dataset.train_test_split(
|
160 |
-
test_size=cfg.val_set_size, shuffle=False
|
161 |
-
)
|
162 |
train_dataset = dataset["train"]
|
163 |
eval_dataset = dataset["test"]
|
164 |
|
|
|
2 |
from hashlib import md5
|
3 |
from pathlib import Path
|
4 |
|
5 |
+
from datasets import (
|
6 |
+
load_from_disk,
|
7 |
+
load_dataset,
|
8 |
+
IterableDataset,
|
9 |
+
Dataset,
|
10 |
+
concatenate_datasets,
|
11 |
+
)
|
12 |
from huggingface_hub import hf_hub_download
|
13 |
|
14 |
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
|
|
81 |
else:
|
82 |
ds = load_dataset(d.path, streaming=True)
|
83 |
else:
|
84 |
+
fp = hf_hub_download(
|
85 |
+
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
86 |
+
)
|
87 |
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
|
88 |
if not ds:
|
89 |
raise Exception("unhandled dataset load")
|
|
|
148 |
samples = samples + [i for i in d]
|
149 |
dataset = Dataset.from_list(samples).shuffle(seed=42)
|
150 |
if cfg.local_rank == 0:
|
151 |
+
logging.info(
|
152 |
+
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
153 |
+
)
|
154 |
dataset.save_to_disk(prepared_ds_path)
|
155 |
|
156 |
if cfg.max_packed_sequence_len is not None:
|
|
|
163 |
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
|
164 |
|
165 |
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
166 |
+
logging.info(
|
167 |
+
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
168 |
+
)
|
169 |
+
dataset = dataset.shard(
|
170 |
+
num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
|
171 |
+
)
|
172 |
|
173 |
+
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
|
|
|
|
174 |
train_dataset = dataset["train"]
|
175 |
eval_dataset = dataset["test"]
|
176 |
|
src/axolotl/utils/models.py
CHANGED
@@ -9,14 +9,18 @@ from transformers import (
|
|
9 |
AutoModelForCausalLM,
|
10 |
AutoTokenizer,
|
11 |
PreTrainedModel,
|
|
|
12 |
)
|
|
|
13 |
try:
|
14 |
from transformers import (
|
15 |
LlamaForCausalLM,
|
16 |
LlamaTokenizer,
|
17 |
)
|
18 |
except:
|
19 |
-
logging.warning(
|
|
|
|
|
20 |
|
21 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
22 |
|
@@ -40,7 +44,9 @@ def load_model(
|
|
40 |
# TODO refactor as a kwarg
|
41 |
load_in_8bit = cfg.load_in_8bit
|
42 |
tokenizer = None
|
43 |
-
is_llama_derived_model = "llama" in base_model or (
|
|
|
|
|
44 |
|
45 |
if is_llama_derived_model and cfg.flash_attention:
|
46 |
if cfg.device not in ["mps", "cpu"] and inference is False:
|
@@ -49,11 +55,16 @@ def load_model(
|
|
49 |
logging.info("patching with flash attention")
|
50 |
replace_llama_attn_with_flash_attn()
|
51 |
elif is_llama_derived_model and cfg.xformers_attention:
|
52 |
-
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import
|
|
|
|
|
|
|
53 |
logging.info("patching with xformers attention")
|
54 |
hijack_llama_attention()
|
55 |
|
56 |
-
torch_dtype =
|
|
|
|
|
57 |
try:
|
58 |
if cfg.load_4bit:
|
59 |
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
@@ -74,8 +85,12 @@ def load_model(
|
|
74 |
try:
|
75 |
snapshot_download_kwargs = {}
|
76 |
if cfg.base_model_ignore_patterns:
|
77 |
-
snapshot_download_kwargs[
|
78 |
-
|
|
|
|
|
|
|
|
|
79 |
files = (
|
80 |
list(cache_model_path.glob("*.pt"))
|
81 |
+ list(cache_model_path.glob("*.safetensors"))
|
@@ -116,8 +131,13 @@ def load_model(
|
|
116 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
117 |
)
|
118 |
else:
|
|
|
|
|
|
|
|
|
119 |
model = AutoModelForCausalLM.from_pretrained(
|
120 |
base_model,
|
|
|
121 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
122 |
torch_dtype=torch_dtype,
|
123 |
device_map=cfg.device_map,
|
|
|
9 |
AutoModelForCausalLM,
|
10 |
AutoTokenizer,
|
11 |
PreTrainedModel,
|
12 |
+
AutoConfig,
|
13 |
)
|
14 |
+
|
15 |
try:
|
16 |
from transformers import (
|
17 |
LlamaForCausalLM,
|
18 |
LlamaTokenizer,
|
19 |
)
|
20 |
except:
|
21 |
+
logging.warning(
|
22 |
+
"This version of transformers does not support Llama. Consider upgrading."
|
23 |
+
)
|
24 |
|
25 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
26 |
|
|
|
44 |
# TODO refactor as a kwarg
|
45 |
load_in_8bit = cfg.load_in_8bit
|
46 |
tokenizer = None
|
47 |
+
is_llama_derived_model = "llama" in base_model or (
|
48 |
+
cfg.model_type and "llama" in cfg.model_type.lower()
|
49 |
+
)
|
50 |
|
51 |
if is_llama_derived_model and cfg.flash_attention:
|
52 |
if cfg.device not in ["mps", "cpu"] and inference is False:
|
|
|
55 |
logging.info("patching with flash attention")
|
56 |
replace_llama_attn_with_flash_attn()
|
57 |
elif is_llama_derived_model and cfg.xformers_attention:
|
58 |
+
from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
|
59 |
+
hijack_llama_attention,
|
60 |
+
)
|
61 |
+
|
62 |
logging.info("patching with xformers attention")
|
63 |
hijack_llama_attention()
|
64 |
|
65 |
+
torch_dtype = (
|
66 |
+
torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
|
67 |
+
)
|
68 |
try:
|
69 |
if cfg.load_4bit:
|
70 |
from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
|
|
|
85 |
try:
|
86 |
snapshot_download_kwargs = {}
|
87 |
if cfg.base_model_ignore_patterns:
|
88 |
+
snapshot_download_kwargs[
|
89 |
+
"ignore_patterns"
|
90 |
+
] = cfg.base_model_ignore_patterns
|
91 |
+
cache_model_path = Path(
|
92 |
+
snapshot_download(base_model, **snapshot_download_kwargs)
|
93 |
+
)
|
94 |
files = (
|
95 |
list(cache_model_path.glob("*.pt"))
|
96 |
+ list(cache_model_path.glob("*.safetensors"))
|
|
|
131 |
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
132 |
)
|
133 |
else:
|
134 |
+
config = AutoConfig.from_pretrained(
|
135 |
+
base_model,
|
136 |
+
trust_remote_code=True if cfg.trust_remote_code is True else False,
|
137 |
+
)
|
138 |
model = AutoModelForCausalLM.from_pretrained(
|
139 |
base_model,
|
140 |
+
config=config,
|
141 |
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
142 |
torch_dtype=torch_dtype,
|
143 |
device_map=cfg.device_map,
|
src/axolotl/utils/schedulers.py
CHANGED
@@ -26,7 +26,10 @@ class InterpolatingLogScheduler(LRScheduler):
|
|
26 |
if self.last_epoch <= 0:
|
27 |
lrs = [self.min_lr for base_lr in self.base_lrs]
|
28 |
elif self.last_epoch < self.num_steps:
|
29 |
-
lrs = [
|
|
|
|
|
|
|
30 |
else:
|
31 |
lrs = [self.max_lr for base_lr in self.base_lrs]
|
32 |
|
|
|
26 |
if self.last_epoch <= 0:
|
27 |
lrs = [self.min_lr for base_lr in self.base_lrs]
|
28 |
elif self.last_epoch < self.num_steps:
|
29 |
+
lrs = [
|
30 |
+
self.min_lr * (self.q ** (self.last_epoch - 1))
|
31 |
+
for base_lr in self.base_lrs
|
32 |
+
]
|
33 |
else:
|
34 |
lrs = [self.max_lr for base_lr in self.base_lrs]
|
35 |
|
src/axolotl/utils/tokenization.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from termcolor import colored
|
2 |
import logging
|
3 |
|
|
|
4 |
def check_dataset_labels(dataset, tokenizer):
|
5 |
# the dataset is already shuffled, so let's just check the first 5 elements
|
6 |
for idx in range(5):
|
@@ -11,7 +12,7 @@ def check_example_labels(example, tokenizer):
|
|
11 |
# Get the input_ids, labels, and attention_mask from the dataset
|
12 |
input_ids = example["input_ids"]
|
13 |
labels = example["labels"]
|
14 |
-
attention_mask =example["attention_mask"]
|
15 |
|
16 |
# You can compare the input_ids and labels element-wise
|
17 |
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
@@ -21,9 +22,7 @@ def check_example_labels(example, tokenizer):
|
|
21 |
):
|
22 |
decoded_input_token = tokenizer.decode(input_id)
|
23 |
# Choose the color based on whether the label has the ignore value or not
|
24 |
-
color = (
|
25 |
-
"red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
|
26 |
-
)
|
27 |
colored_token = colored(decoded_input_token, color) + colored(
|
28 |
f"({label_id}, {mask}, {input_id})", "white"
|
29 |
)
|
|
|
1 |
from termcolor import colored
|
2 |
import logging
|
3 |
|
4 |
+
|
5 |
def check_dataset_labels(dataset, tokenizer):
|
6 |
# the dataset is already shuffled, so let's just check the first 5 elements
|
7 |
for idx in range(5):
|
|
|
12 |
# Get the input_ids, labels, and attention_mask from the dataset
|
13 |
input_ids = example["input_ids"]
|
14 |
labels = example["labels"]
|
15 |
+
attention_mask = example["attention_mask"]
|
16 |
|
17 |
# You can compare the input_ids and labels element-wise
|
18 |
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
|
|
22 |
):
|
23 |
decoded_input_token = tokenizer.decode(input_id)
|
24 |
# Choose the color based on whether the label has the ignore value or not
|
25 |
+
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
|
|
|
|
|
26 |
colored_token = colored(decoded_input_token, color) + colored(
|
27 |
f"({label_id}, {mask}, {input_id})", "white"
|
28 |
)
|
src/axolotl/utils/trainer.py
CHANGED
@@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
30 |
if cfg.logging_steps is not None
|
31 |
else max(min(int(0.005 * total_num_steps), 10), 1)
|
32 |
)
|
33 |
-
save_steps =
|
34 |
-
|
35 |
-
if cfg.save_steps is not None
|
36 |
-
else min(int(0.05 * total_num_steps), 200)
|
37 |
-
)
|
38 |
-
eval_steps = (
|
39 |
-
cfg.eval_steps
|
40 |
-
if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
|
41 |
-
else save_steps
|
42 |
-
)
|
43 |
|
44 |
training_arguments_kwargs = {}
|
45 |
if cfg.bf16 == "full":
|
@@ -86,26 +78,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
86 |
|
87 |
training_args = transformers.TrainingArguments(
|
88 |
per_device_train_batch_size=cfg.micro_batch_size,
|
89 |
-
per_device_eval_batch_size=cfg.eval_batch_size
|
|
|
|
|
90 |
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
91 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
92 |
num_train_epochs=cfg.num_epochs,
|
93 |
learning_rate=cfg.learning_rate,
|
94 |
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
95 |
-
save_strategy="steps",
|
96 |
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
|
97 |
save_steps=save_steps,
|
98 |
output_dir=cfg.output_dir,
|
99 |
save_total_limit=3,
|
100 |
load_best_model_at_end=True
|
101 |
-
if cfg.val_set_size > 0
|
|
|
|
|
|
|
102 |
else False,
|
103 |
ddp_find_unused_parameters=False if cfg.ddp else None,
|
104 |
group_by_length=cfg.group_by_length,
|
105 |
report_to="wandb" if cfg.use_wandb else None,
|
106 |
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
107 |
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
|
108 |
-
lr_scheduler_type=cfg.lr_scheduler
|
|
|
|
|
109 |
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
110 |
**training_arguments_kwargs,
|
111 |
)
|
@@ -158,6 +157,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
158 |
cfg.learning_rate,
|
159 |
total_steps=total_num_steps,
|
160 |
epochs=cfg.num_epochs,
|
|
|
161 |
**lr_scheduler_kwargs,
|
162 |
)
|
163 |
elif cfg.lr_scheduler == "log_sweep":
|
@@ -191,7 +191,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
191 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
192 |
|
193 |
callbacks = []
|
194 |
-
if cfg.adapter ==
|
195 |
callbacks.append(SavePeftModelCallback)
|
196 |
|
197 |
trainer = transformers.Trainer(
|
|
|
30 |
if cfg.logging_steps is not None
|
31 |
else max(min(int(0.005 * total_num_steps), 10), 1)
|
32 |
)
|
33 |
+
save_steps = cfg.save_steps
|
34 |
+
eval_steps = cfg.eval_steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
training_arguments_kwargs = {}
|
37 |
if cfg.bf16 == "full":
|
|
|
78 |
|
79 |
training_args = transformers.TrainingArguments(
|
80 |
per_device_train_batch_size=cfg.micro_batch_size,
|
81 |
+
per_device_eval_batch_size=cfg.eval_batch_size
|
82 |
+
if cfg.eval_batch_size is not None
|
83 |
+
else cfg.micro_batch_size,
|
84 |
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
85 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
86 |
num_train_epochs=cfg.num_epochs,
|
87 |
learning_rate=cfg.learning_rate,
|
88 |
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
89 |
+
save_strategy="steps" if save_steps else "epoch",
|
90 |
eval_steps=eval_steps if cfg.val_set_size > 0 else None,
|
91 |
save_steps=save_steps,
|
92 |
output_dir=cfg.output_dir,
|
93 |
save_total_limit=3,
|
94 |
load_best_model_at_end=True
|
95 |
+
if cfg.val_set_size > 0
|
96 |
+
and save_steps is not None
|
97 |
+
and save_steps % eval_steps == 0
|
98 |
+
and cfg.load_in_8bit is not True
|
99 |
else False,
|
100 |
ddp_find_unused_parameters=False if cfg.ddp else None,
|
101 |
group_by_length=cfg.group_by_length,
|
102 |
report_to="wandb" if cfg.use_wandb else None,
|
103 |
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
104 |
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
|
105 |
+
lr_scheduler_type=cfg.lr_scheduler
|
106 |
+
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
107 |
+
else "cosine",
|
108 |
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
109 |
**training_arguments_kwargs,
|
110 |
)
|
|
|
157 |
cfg.learning_rate,
|
158 |
total_steps=total_num_steps,
|
159 |
epochs=cfg.num_epochs,
|
160 |
+
div_factor=10,
|
161 |
**lr_scheduler_kwargs,
|
162 |
)
|
163 |
elif cfg.lr_scheduler == "log_sweep":
|
|
|
191 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
192 |
|
193 |
callbacks = []
|
194 |
+
if cfg.adapter == "lora":
|
195 |
callbacks.append(SavePeftModelCallback)
|
196 |
|
197 |
trainer = transformers.Trainer(
|
src/axolotl/utils/wandb.py
CHANGED
@@ -2,7 +2,9 @@ import os
|
|
2 |
|
3 |
|
4 |
def setup_wandb_env_vars(cfg):
|
5 |
-
if cfg.
|
|
|
|
|
6 |
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
7 |
cfg.use_wandb = True
|
8 |
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
|
|
2 |
|
3 |
|
4 |
def setup_wandb_env_vars(cfg):
|
5 |
+
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
6 |
+
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
7 |
+
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
8 |
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
9 |
cfg.use_wandb = True
|
10 |
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|