config chooser, update readme instructions, device config, llama flash attention, debug out the labels, fix config key checks, other bugfixes
Browse files- README.md +21 -1
- configs/cerebras_1_3B_alpaca.yml +38 -0
- requirements.txt +6 -0
- scripts/finetune.py +95 -7
- setup.cfg +6 -0
- src/axolotl/flash_attn.py +116 -0
- src/axolotl/prompt_tokenizers.py +1 -1
README.md
CHANGED
@@ -2,8 +2,18 @@
|
|
2 |
|
3 |
#### You know you're going to axolotl questions
|
4 |
|
|
|
5 |
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
```shell
|
9 |
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
|
@@ -11,3 +21,13 @@ python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json >
|
|
11 |
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
12 |
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
13 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
#### You know you're going to axolotl questions
|
4 |
|
5 |
+
## Getting Started
|
6 |
|
7 |
+
- Download some datasets.
|
8 |
+
|
9 |
+
```shell
|
10 |
+
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
|
11 |
+
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
|
12 |
+
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
|
13 |
+
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
|
14 |
+
```
|
15 |
+
|
16 |
+
- Convert the JSON data files to JSONL.
|
17 |
|
18 |
```shell
|
19 |
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
|
|
|
21 |
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
22 |
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
23 |
```
|
24 |
+
|
25 |
+
- Using JSONL makes it easier to subset the data if you want a smaller training set, i.e get 2000 random examples.
|
26 |
+
|
27 |
+
```shell
|
28 |
+
shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
|
29 |
+
```
|
30 |
+
|
31 |
+
- Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
|
32 |
+
- Install python dependencies `pip3 install -r requirements.txt`
|
33 |
+
- Train! `python3 scripts/finetune.py`, make sure to choose the correct YAML config file
|
configs/cerebras_1_3B_alpaca.yml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_model: cerebras/Cerebras-GPT-1.3B
|
2 |
+
model_type: AutoModelForCausalLM
|
3 |
+
tokenizer_type: AutoTokenizer
|
4 |
+
load_in_8bit: true
|
5 |
+
datasets:
|
6 |
+
- path: data/alpaca_data_gpt4.jsonl
|
7 |
+
type: alpaca
|
8 |
+
- path: data/vicuna_cleaned.jsonl
|
9 |
+
type: sharegpt
|
10 |
+
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
|
11 |
+
type: gpteacher
|
12 |
+
- path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
|
13 |
+
type: gpteacher
|
14 |
+
val_set_size: 0.05
|
15 |
+
adapter: lora
|
16 |
+
sequence_len: 2048
|
17 |
+
lora_r: 8
|
18 |
+
lora_alpha: 16
|
19 |
+
lora_dropout: 0.05
|
20 |
+
lora_target_modules:
|
21 |
+
- c_attn
|
22 |
+
lora_fan_in_fan_out: false
|
23 |
+
wandb_project: pythia-1.4b-lora
|
24 |
+
wandb_watch:
|
25 |
+
wandb_run_name:
|
26 |
+
wandb_log_model: checkpoint
|
27 |
+
output_dir: ./lora-alpaca
|
28 |
+
batch_size: 32
|
29 |
+
micro_batch_size: 4
|
30 |
+
num_epochs: 5
|
31 |
+
learning_rate: 0.0003
|
32 |
+
train_on_inputs: false
|
33 |
+
group_by_length: false
|
34 |
+
bf16: True
|
35 |
+
tf32: True
|
36 |
+
resume_from_checkpoint:
|
37 |
+
local_rank:
|
38 |
+
deepspeed:
|
requirements.txt
CHANGED
@@ -4,3 +4,9 @@ attrdict
|
|
4 |
fire
|
5 |
PyYAML==6.0
|
6 |
black
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
fire
|
5 |
PyYAML==6.0
|
6 |
black
|
7 |
+
bitsandbytes
|
8 |
+
datasets
|
9 |
+
accelerate
|
10 |
+
sentencepiece
|
11 |
+
wandb
|
12 |
+
flash-attn
|
scripts/finetune.py
CHANGED
@@ -9,7 +9,7 @@ import fire
|
|
9 |
import torch
|
10 |
import transformers
|
11 |
import yaml
|
12 |
-
from attrdict import
|
13 |
from datasets import load_dataset, IterableDataset, Dataset
|
14 |
from peft import (
|
15 |
LoraConfig,
|
@@ -50,6 +50,11 @@ def setup_wandb_env_vars(cfg):
|
|
50 |
def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
51 |
if adapter != "lora":
|
52 |
raise NotImplementedError(f"{adapter} peft adapter not available")
|
|
|
|
|
|
|
|
|
|
|
53 |
try:
|
54 |
model = getattr(transformers, model_type).from_pretrained(
|
55 |
base_model,
|
@@ -99,24 +104,104 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|
99 |
return model, tokenizer, lora_config
|
100 |
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
def train(
|
103 |
-
config: Path = Path("configs/
|
104 |
**kwargs,
|
105 |
):
|
|
|
|
|
|
|
106 |
# load the config from the yaml file
|
107 |
with open(config, "r") as f:
|
108 |
-
cfg:
|
109 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
110 |
# then overwrite the value
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
# setup some derived config / hyperparams
|
116 |
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
|
117 |
-
cfg.device_map = "auto"
|
118 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
119 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
|
120 |
cfg.ddp = cfg.world_size != 1
|
121 |
if cfg.ddp:
|
122 |
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
@@ -163,6 +248,8 @@ def train(
|
|
163 |
train_dataset = constant_len_dataset["train"]
|
164 |
eval_dataset = constant_len_dataset["test"]
|
165 |
|
|
|
|
|
166 |
total_num_steps = int(
|
167 |
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
168 |
)
|
@@ -240,6 +327,7 @@ def train(
|
|
240 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
241 |
model = torch.compile(model)
|
242 |
|
|
|
243 |
signal.signal(
|
244 |
signal.SIGINT,
|
245 |
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
|
|
|
9 |
import torch
|
10 |
import transformers
|
11 |
import yaml
|
12 |
+
from attrdict import AttrDefault
|
13 |
from datasets import load_dataset, IterableDataset, Dataset
|
14 |
from peft import (
|
15 |
LoraConfig,
|
|
|
50 |
def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
51 |
if adapter != "lora":
|
52 |
raise NotImplementedError(f"{adapter} peft adapter not available")
|
53 |
+
if "llama" in base_model:
|
54 |
+
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
55 |
+
|
56 |
+
replace_llama_attn_with_flash_attn()
|
57 |
+
|
58 |
try:
|
59 |
model = getattr(transformers, model_type).from_pretrained(
|
60 |
base_model,
|
|
|
104 |
return model, tokenizer, lora_config
|
105 |
|
106 |
|
107 |
+
def choose_device(cfg):
|
108 |
+
def get_device():
|
109 |
+
if torch.cuda.is_available():
|
110 |
+
return "cuda"
|
111 |
+
else:
|
112 |
+
try:
|
113 |
+
if torch.backends.mps.is_available():
|
114 |
+
return "mps"
|
115 |
+
except:
|
116 |
+
return "cpu"
|
117 |
+
|
118 |
+
cfg.device = get_device()
|
119 |
+
if cfg.device == "cuda":
|
120 |
+
cfg.device_map = {"": cfg.local_rank}
|
121 |
+
else:
|
122 |
+
cfg.device_map = {"": cfg.device}
|
123 |
+
|
124 |
+
|
125 |
+
def check_dataset_labels(dataset, tokenizer):
|
126 |
+
from termcolor import colored
|
127 |
+
|
128 |
+
# the dataset is already shuffled, so let's just check the first 5 elements
|
129 |
+
for idx in range(5):
|
130 |
+
# Get the input_ids, labels, and attention_mask from the dataset
|
131 |
+
input_ids = dataset[idx]["input_ids"]
|
132 |
+
labels = dataset[idx]["labels"]
|
133 |
+
attention_mask = dataset[idx]["attention_mask"]
|
134 |
+
|
135 |
+
# You can compare the input_ids and labels element-wise
|
136 |
+
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
137 |
+
colored_tokens = []
|
138 |
+
for i, (input_id, label_id, mask) in enumerate(
|
139 |
+
zip(input_ids, labels, attention_mask)
|
140 |
+
):
|
141 |
+
decoded_input_token = tokenizer.decode(input_id)
|
142 |
+
# Choose the color based on whether the label has the ignore value or not
|
143 |
+
color = (
|
144 |
+
"red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
|
145 |
+
)
|
146 |
+
colored_token = colored(decoded_input_token, color) + colored(
|
147 |
+
f"({label_id}, {mask})", "white"
|
148 |
+
)
|
149 |
+
colored_tokens.append(colored_token)
|
150 |
+
|
151 |
+
print(" ".join(colored_tokens))
|
152 |
+
print("\n\n\n")
|
153 |
+
|
154 |
+
|
155 |
+
def choose_config(path: Path):
|
156 |
+
yaml_files = [file for file in path.glob("*.yml")]
|
157 |
+
|
158 |
+
if not yaml_files:
|
159 |
+
raise ValueError("No YAML config files found in the specified directory. Are you using a .yml extension?")
|
160 |
+
|
161 |
+
print("Choose a YAML file:")
|
162 |
+
for idx, file in enumerate(yaml_files):
|
163 |
+
print(f"{idx + 1}. {file}")
|
164 |
+
|
165 |
+
chosen_file = None
|
166 |
+
while chosen_file is None:
|
167 |
+
try:
|
168 |
+
choice = int(input("Enter the number of your choice: "))
|
169 |
+
if 1 <= choice <= len(yaml_files):
|
170 |
+
chosen_file = yaml_files[choice - 1]
|
171 |
+
else:
|
172 |
+
print("Invalid choice. Please choose a number from the list.")
|
173 |
+
except ValueError:
|
174 |
+
print("Invalid input. Please enter a number.")
|
175 |
+
|
176 |
+
return chosen_file
|
177 |
+
|
178 |
+
|
179 |
def train(
|
180 |
+
config: Path = Path("configs/"),
|
181 |
**kwargs,
|
182 |
):
|
183 |
+
if config.is_dir():
|
184 |
+
config = choose_config(config)
|
185 |
+
|
186 |
# load the config from the yaml file
|
187 |
with open(config, "r") as f:
|
188 |
+
cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader))
|
189 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
190 |
# then overwrite the value
|
191 |
+
cfg_keys = dict(cfg).keys()
|
192 |
+
for k in kwargs:
|
193 |
+
if k in cfg_keys:
|
194 |
+
# handle booleans
|
195 |
+
if isinstance(cfg[k], bool):
|
196 |
+
cfg[k] = bool(kwargs[k])
|
197 |
+
else:
|
198 |
+
cfg[k] = kwargs[k]
|
199 |
|
200 |
# setup some derived config / hyperparams
|
201 |
cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
|
|
|
202 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
203 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
204 |
+
choose_device(cfg)
|
205 |
cfg.ddp = cfg.world_size != 1
|
206 |
if cfg.ddp:
|
207 |
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
|
|
248 |
train_dataset = constant_len_dataset["train"]
|
249 |
eval_dataset = constant_len_dataset["test"]
|
250 |
|
251 |
+
# check_dataset_labels(eval_dataset, tokenizer)
|
252 |
+
|
253 |
total_num_steps = int(
|
254 |
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
255 |
)
|
|
|
327 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
328 |
model = torch.compile(model)
|
329 |
|
330 |
+
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
331 |
signal.signal(
|
332 |
signal.SIGINT,
|
333 |
lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
|
setup.cfg
CHANGED
@@ -17,6 +17,12 @@ install_requires =
|
|
17 |
fire
|
18 |
PyYAML == 6.0
|
19 |
black
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
[options.packages.find]
|
22 |
where = src
|
|
|
17 |
fire
|
18 |
PyYAML == 6.0
|
19 |
black
|
20 |
+
bitsandbytes
|
21 |
+
datasets
|
22 |
+
accelerate
|
23 |
+
sentencepiece
|
24 |
+
wandb
|
25 |
+
flash-attn
|
26 |
|
27 |
[options.packages.find]
|
28 |
where = src
|
src/axolotl/flash_attn.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
2 |
+
|
3 |
+
from typing import List, Optional, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
import transformers
|
9 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
10 |
+
|
11 |
+
from einops import rearrange
|
12 |
+
|
13 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
14 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
15 |
+
|
16 |
+
|
17 |
+
def forward(
|
18 |
+
self,
|
19 |
+
hidden_states: torch.Tensor,
|
20 |
+
attention_mask: Optional[torch.Tensor] = None,
|
21 |
+
position_ids: Optional[torch.Tensor] = None,
|
22 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
23 |
+
output_attentions: bool = False,
|
24 |
+
use_cache: bool = False,
|
25 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
26 |
+
"""Input shape: Batch x Time x Channel
|
27 |
+
|
28 |
+
attention_mask: [bsz, q_len]
|
29 |
+
"""
|
30 |
+
bsz, q_len, _ = hidden_states.size()
|
31 |
+
|
32 |
+
query_states = (
|
33 |
+
self.q_proj(hidden_states)
|
34 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
35 |
+
.transpose(1, 2)
|
36 |
+
)
|
37 |
+
key_states = (
|
38 |
+
self.k_proj(hidden_states)
|
39 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
40 |
+
.transpose(1, 2)
|
41 |
+
)
|
42 |
+
value_states = (
|
43 |
+
self.v_proj(hidden_states)
|
44 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
45 |
+
.transpose(1, 2)
|
46 |
+
)
|
47 |
+
# [bsz, q_len, nh, hd]
|
48 |
+
# [bsz, nh, q_len, hd]
|
49 |
+
|
50 |
+
kv_seq_len = key_states.shape[-2]
|
51 |
+
assert past_key_value is None, "past_key_value is not supported"
|
52 |
+
|
53 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
54 |
+
query_states, key_states = apply_rotary_pos_emb(
|
55 |
+
query_states, key_states, cos, sin, position_ids
|
56 |
+
)
|
57 |
+
# [bsz, nh, t, hd]
|
58 |
+
assert not output_attentions, "output_attentions is not supported"
|
59 |
+
assert not use_cache, "use_cache is not supported"
|
60 |
+
|
61 |
+
# Flash attention codes from
|
62 |
+
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
63 |
+
|
64 |
+
# transform the data into the format required by flash attention
|
65 |
+
qkv = torch.stack(
|
66 |
+
[query_states, key_states, value_states], dim=2
|
67 |
+
) # [bsz, nh, 3, q_len, hd]
|
68 |
+
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
69 |
+
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
70 |
+
# the attention_mask should be the same as the key_padding_mask
|
71 |
+
key_padding_mask = attention_mask
|
72 |
+
|
73 |
+
if key_padding_mask is None:
|
74 |
+
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
75 |
+
max_s = q_len
|
76 |
+
cu_q_lens = torch.arange(
|
77 |
+
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
78 |
+
)
|
79 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
80 |
+
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
81 |
+
)
|
82 |
+
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
83 |
+
else:
|
84 |
+
nheads = qkv.shape[-2]
|
85 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
86 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
87 |
+
x_unpad = rearrange(
|
88 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
89 |
+
)
|
90 |
+
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
91 |
+
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
92 |
+
)
|
93 |
+
output = rearrange(
|
94 |
+
pad_input(
|
95 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
96 |
+
),
|
97 |
+
"b s (h d) -> b s h d",
|
98 |
+
h=nheads,
|
99 |
+
)
|
100 |
+
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
101 |
+
|
102 |
+
|
103 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
104 |
+
# requires the attention mask to be the same as the key_padding_mask
|
105 |
+
def _prepare_decoder_attention_mask(
|
106 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
107 |
+
):
|
108 |
+
# [bsz, seq_len]
|
109 |
+
return attention_mask
|
110 |
+
|
111 |
+
|
112 |
+
def replace_llama_attn_with_flash_attn():
|
113 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
114 |
+
_prepare_decoder_attention_mask
|
115 |
+
)
|
116 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -88,5 +88,5 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
88 |
def tokenize_prompt(self, prompt):
|
89 |
try:
|
90 |
return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
|
91 |
-
except (KeyError, AssertionError) as e:
|
92 |
raise InvalidDataException(str(e))
|
|
|
88 |
def tokenize_prompt(self, prompt):
|
89 |
try:
|
90 |
return self.prompter.build_prompt(prompt["conversations"], self.tokenizer)
|
91 |
+
except (KeyError, AssertionError, IndexError) as e:
|
92 |
raise InvalidDataException(str(e))
|