various bugfixes
- configs/llama_65B_alpaca.yml +3 -3
- requirements.txt +3 -0
- scripts/finetune.py +8 -7
- setup.cfg +1 -0
- src/axolotl/datasets.py +21 -19
configs/llama_65B_alpaca.yml
CHANGED
@@ -1,4 +1,4 @@
-base_model:
+base_model: huggyllama/llama-7b
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: true
@@ -33,8 +33,8 @@ num_epochs: 5
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16:
-tf32:
+bf16: true
+tf32: true
 resume_from_checkpoint:
 local_rank:
 deepspeed:
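For context (not part of this commit): a minimal sketch of how boolean flags like bf16 and tf32 from a config such as this one are commonly wired into Hugging Face TrainingArguments. The build_training_args helper and the output_dir value are illustrative assumptions, not axolotl code.

import torch
from transformers import TrainingArguments

def build_training_args(cfg):
    # Hypothetical helper; assumes `cfg` exposes the YAML keys shown above.
    if cfg.tf32:
        # TF32 matmuls only take effect on Ampere-or-newer NVIDIA GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    return TrainingArguments(
        output_dir="./out",               # placeholder path
        num_train_epochs=cfg.num_epochs,  # 5 in this config
        learning_rate=cfg.learning_rate,  # 0.00003 in this config
        bf16=bool(cfg.bf16),              # bfloat16 mixed precision
        tf32=bool(cfg.tf32),              # lets the Trainer enable TF32 as well
    )

Both flags are hardware-dependent: bf16 needs a bfloat16-capable accelerator, and tf32 has no effect on pre-Ampere GPUs.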
requirements.txt
CHANGED
@@ -10,3 +10,6 @@ accelerate
 sentencepiece
 wandb
 flash-attn
+deepspeed
+einops
+
scripts/finetune.py
CHANGED
@@ -11,7 +11,7 @@ import torch
 import transformers
 import yaml
 from attrdict import AttrDefault
-from datasets import load_dataset, IterableDataset, Dataset
+from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
 from peft import (
     LoraConfig,
     get_peft_model,
@@ -52,8 +52,9 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
-        from axolotl.flash_attn import replace_llama_attn_with_flash_attn
-        replace_llama_attn_with_flash_attn()
+        if cfg.device not in ["mps", "cpu"]:
+            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            replace_llama_attn_with_flash_attn()

     try:
         if "llama" in base_model:
@@ -86,7 +87,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     except:
         tokenizer = AutoTokenizer.from_pretrained(base_model)

-    if tokenizer.__class__.__name__ == "LlamaTokenizer":
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
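To make the new device guard concrete, a standalone sketch of the pattern (the maybe_patch_llama_attention wrapper and its return values are hypothetical; only replace_llama_attn_with_flash_attn comes from the diff):

def maybe_patch_llama_attention(cfg):
    # Illustrative wrapper, not the repo's call site: skip the flash-attn
    # monkeypatch on devices that cannot run its CUDA kernels.
    if cfg.device in ["mps", "cpu"]:
        return False
    from axolotl.flash_attn import replace_llama_attn_with_flash_attn
    replace_llama_attn_with_flash_attn()
    return True

In the hunk above the call happens inside load_model, before the model itself is loaded.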
@@ -255,8 +256,9 @@ def train(
         return

     datasets = []
-    if
-
+    if not isinstance(cfg.datasets, list) and isinstance(cfg.datasets, str):
+        # assumption that we are loading a previously saved/cached dataset
+        dataset = load_from_disk(cfg.datasets)
     else:
         for d in cfg.datasets:
             ds: IterableDataset = load_dataset(
@@ -288,7 +290,6 @@ def train(
             [_ for _ in constant_len_dataset]
         ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
         dataset.save_to_disk("data/last_run")
-        print(dataset)

     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
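To illustrate the cache path these train() changes rely on: save_to_disk writes the prepared splits to a directory, and pointing cfg.datasets at that directory as a plain string is what the new isinstance check routes to load_from_disk. A toy round trip with the datasets library (the example rows and the data/example_run path are placeholders):

from datasets import Dataset, load_from_disk

# Build and save a tiny prepared dataset (stands in for what finetune.py
# writes to data/last_run).
prepared = Dataset.from_list(
    [
        {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
        {"input_ids": [4, 5, 6], "labels": [4, 5, 6]},
    ]
).train_test_split(test_size=0.5, seed=42)
prepared.save_to_disk("data/example_run")  # placeholder path

# A later run can set cfg.datasets to that directory (a str, not a list),
# which is the case the new branch handles.
dataset = load_from_disk("data/example_run")
train_dataset = dataset["train"]
eval_dataset = dataset["test"]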
setup.cfg
CHANGED
@@ -23,6 +23,7 @@ install_requires =
     sentencepiece
     wandb
     flash-attn
+    einops

 [options.packages.find]
 where = src
src/axolotl/datasets.py
CHANGED
@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
                 buffer_len = 0

             if example:
-                input_ids = example["input_ids"]
-                attention_mask = example["attention_mask"]
-                labels = example["labels"]
-
-                if add_concat_token:
-                    input_ids.append(self.concat_token_id)
-                    attention_mask.append(1)
-                    labels.append(self.concat_token_id)
-
-                input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
-                attention_mask_with_concat = torch.tensor(
-                    attention_mask, dtype=torch.long
-                )
-                labels_with_concat = torch.tensor(labels, dtype=torch.long)
-
-                buffer["input_ids"].append(input_ids_with_concat)
-                buffer["attention_mask"].append(attention_mask_with_concat)
-                buffer["labels"].append(labels_with_concat)
-                buffer_len += len(input_ids)
+                # just going to drop data points that are too long
+                if len(example["input_ids"]) <= self.seq_length:
+                    input_ids = example["input_ids"]
+                    attention_mask = example["attention_mask"]
+                    labels = example["labels"]
+
+                    if add_concat_token:
+                        input_ids.append(self.concat_token_id)
+                        attention_mask.append(1)
+                        labels.append(self.concat_token_id)
+
+                    input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
+                    attention_mask_with_concat = torch.tensor(
+                        attention_mask, dtype=torch.long
+                    )
+                    labels_with_concat = torch.tensor(labels, dtype=torch.long)
+
+                    buffer["input_ids"].append(input_ids_with_concat)
+                    buffer["attention_mask"].append(attention_mask_with_concat)
+                    buffer["labels"].append(labels_with_concat)
+                    buffer_len += len(input_ids)
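A self-contained sketch of the new packing step, pulled out of ConstantLengthDataset for clarity: over-long examples are dropped, everything else gets the concat token appended and is buffered as tensors. The pack_example function is illustrative; the field names, the drop rule, and the buffer layout come from the diff, the rest is assumed.

import torch

def pack_example(example, buffer, buffer_len, seq_length, concat_token_id,
                 add_concat_token=True):
    # Drop data points that are too long, as the added comment says.
    if len(example["input_ids"]) > seq_length:
        return buffer_len
    input_ids = list(example["input_ids"])
    attention_mask = list(example["attention_mask"])
    labels = list(example["labels"])
    if add_concat_token:
        # Separate packed examples with the concat token.
        input_ids.append(concat_token_id)
        attention_mask.append(1)
        labels.append(concat_token_id)
    buffer["input_ids"].append(torch.tensor(input_ids, dtype=torch.long))
    buffer["attention_mask"].append(torch.tensor(attention_mask, dtype=torch.long))
    buffer["labels"].append(torch.tensor(labels, dtype=torch.long))
    return buffer_len + len(input_ids)

In the full class, buffer_len is compared against seq_length to decide when a packed sequence is emitted and the buffer reset; that logic sits outside this hunk.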