Merge pull request #124 from OpenAccess-AI-Collective/xformers-fix
Browse filescopy xformers attn from ooba since we removed dep on alpaca_lora_4bit
- .mypy.ini +6 -0
- README.md +3 -0
- src/axolotl/flash_attn.py +1 -0
- src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +233 -0
- src/axolotl/utils/models.py +8 -1
.mypy.ini
CHANGED
@@ -5,6 +5,9 @@ exclude = venv
|
|
5 |
[mypy-alpaca_lora_4bit.*]
|
6 |
ignore_missing_imports = True
|
7 |
|
|
|
|
|
|
|
8 |
[mypy-flash_attn.*]
|
9 |
ignore_missing_imports = True
|
10 |
|
@@ -31,3 +34,6 @@ ignore_missing_imports = True
|
|
31 |
|
32 |
[mypy-addict]
|
33 |
ignore_missing_imports = True
|
|
|
|
|
|
|
|
5 |
[mypy-alpaca_lora_4bit.*]
|
6 |
ignore_missing_imports = True
|
7 |
|
8 |
+
[mypy-axolotl.monkeypatch.*]
|
9 |
+
ignore_errors = True
|
10 |
+
|
11 |
[mypy-flash_attn.*]
|
12 |
ignore_missing_imports = True
|
13 |
|
|
|
34 |
|
35 |
[mypy-addict]
|
36 |
ignore_missing_imports = True
|
37 |
+
|
38 |
+
[mypy-xformers.*]
|
39 |
+
ignore_missing_imports = True
|
README.md
CHANGED
@@ -303,6 +303,9 @@ weight_decay:
|
|
303 |
xformers_attention:
|
304 |
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
305 |
flash_attention: # require a100 for llama
|
|
|
|
|
|
|
306 |
|
307 |
# resume from a specific checkpoint dir
|
308 |
resume_from_checkpoint:
|
|
|
303 |
xformers_attention:
|
304 |
# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
|
305 |
flash_attention: # require a100 for llama
|
306 |
+
# whether to use scaled-dot-product attention
|
307 |
+
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
308 |
+
sdp_attention:
|
309 |
|
310 |
# resume from a specific checkpoint dir
|
311 |
resume_from_checkpoint:
|
src/axolotl/flash_attn.py
CHANGED
@@ -25,6 +25,7 @@ def forward(
|
|
25 |
|
26 |
attention_mask: [bsz, q_len]
|
27 |
"""
|
|
|
28 |
bsz, q_len, _ = hidden_states.size()
|
29 |
|
30 |
query_states = (
|
|
|
25 |
|
26 |
attention_mask: [bsz, q_len]
|
27 |
"""
|
28 |
+
# pylint: disable=duplicate-code
|
29 |
bsz, q_len, _ = hidden_states.size()
|
30 |
|
31 |
query_states = (
|
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import math
|
7 |
+
from typing import Optional, Tuple
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import transformers.models.llama.modeling_llama
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
try:
|
14 |
+
import xformers.ops
|
15 |
+
except ImportError:
|
16 |
+
logging.error("xformers not found! Please install it before trying to use it.")
|
17 |
+
|
18 |
+
|
19 |
+
def hijack_llama_attention():
|
20 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
21 |
+
|
22 |
+
|
23 |
+
def hijack_llama_sdp_attention():
|
24 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
25 |
+
sdp_attention_forward
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
def xformers_forward(
|
30 |
+
self,
|
31 |
+
hidden_states: torch.Tensor,
|
32 |
+
attention_mask: Optional[torch.Tensor] = None,
|
33 |
+
position_ids: Optional[torch.LongTensor] = None,
|
34 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
35 |
+
output_attentions: bool = False,
|
36 |
+
use_cache: bool = False,
|
37 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
38 |
+
# pylint: disable=duplicate-code
|
39 |
+
bsz, q_len, _ = hidden_states.size()
|
40 |
+
|
41 |
+
query_states = (
|
42 |
+
self.q_proj(hidden_states)
|
43 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
44 |
+
.transpose(1, 2)
|
45 |
+
)
|
46 |
+
key_states = (
|
47 |
+
self.k_proj(hidden_states)
|
48 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
49 |
+
.transpose(1, 2)
|
50 |
+
)
|
51 |
+
value_states = (
|
52 |
+
self.v_proj(hidden_states)
|
53 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
54 |
+
.transpose(1, 2)
|
55 |
+
)
|
56 |
+
|
57 |
+
kv_seq_len = key_states.shape[-2]
|
58 |
+
if past_key_value is not None:
|
59 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
60 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
61 |
+
(
|
62 |
+
query_states,
|
63 |
+
key_states,
|
64 |
+
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
65 |
+
query_states, key_states, cos, sin, position_ids
|
66 |
+
)
|
67 |
+
# [bsz, nh, t, hd]
|
68 |
+
|
69 |
+
if past_key_value is not None:
|
70 |
+
# reuse k, v, self_attention
|
71 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
72 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
73 |
+
|
74 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
75 |
+
|
76 |
+
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
77 |
+
if not output_attentions:
|
78 |
+
query_states = query_states.transpose(1, 2)
|
79 |
+
key_states = key_states.transpose(1, 2)
|
80 |
+
value_states = value_states.transpose(1, 2)
|
81 |
+
|
82 |
+
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
83 |
+
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
84 |
+
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
85 |
+
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
86 |
+
attn_output = xformers.ops.memory_efficient_attention(
|
87 |
+
query_states, key_states, value_states, attn_bias=None
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
91 |
+
attn_output = xformers.ops.memory_efficient_attention(
|
92 |
+
query_states,
|
93 |
+
key_states,
|
94 |
+
value_states,
|
95 |
+
attn_bias=xformers.ops.LowerTriangularMask(),
|
96 |
+
)
|
97 |
+
attn_weights = None
|
98 |
+
else:
|
99 |
+
attn_weights = torch.matmul(
|
100 |
+
query_states, key_states.transpose(2, 3)
|
101 |
+
) / math.sqrt(self.head_dim)
|
102 |
+
|
103 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
104 |
+
raise ValueError(
|
105 |
+
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
106 |
+
f" {attn_weights.size()}"
|
107 |
+
)
|
108 |
+
|
109 |
+
if attention_mask is not None:
|
110 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
111 |
+
raise ValueError(
|
112 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
113 |
+
)
|
114 |
+
attn_weights = attn_weights + attention_mask
|
115 |
+
attn_weights = torch.max(
|
116 |
+
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
117 |
+
)
|
118 |
+
|
119 |
+
# upcast attention to fp32
|
120 |
+
attn_weights = nn.functional.softmax(
|
121 |
+
attn_weights, dim=-1, dtype=torch.float32
|
122 |
+
).to(query_states.dtype)
|
123 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
124 |
+
|
125 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
126 |
+
raise ValueError(
|
127 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
128 |
+
f" {attn_output.size()}"
|
129 |
+
)
|
130 |
+
|
131 |
+
attn_output = attn_output.transpose(1, 2)
|
132 |
+
|
133 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
134 |
+
attn_output = self.o_proj(attn_output)
|
135 |
+
return attn_output, attn_weights, past_key_value
|
136 |
+
|
137 |
+
|
138 |
+
def sdp_attention_forward(
|
139 |
+
self,
|
140 |
+
hidden_states: torch.Tensor,
|
141 |
+
attention_mask: Optional[torch.Tensor] = None,
|
142 |
+
position_ids: Optional[torch.LongTensor] = None,
|
143 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
144 |
+
output_attentions: bool = False,
|
145 |
+
use_cache: bool = False,
|
146 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
147 |
+
# pylint: disable=duplicate-code
|
148 |
+
bsz, q_len, _ = hidden_states.size()
|
149 |
+
|
150 |
+
query_states = (
|
151 |
+
self.q_proj(hidden_states)
|
152 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
153 |
+
.transpose(1, 2)
|
154 |
+
)
|
155 |
+
key_states = (
|
156 |
+
self.k_proj(hidden_states)
|
157 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
158 |
+
.transpose(1, 2)
|
159 |
+
)
|
160 |
+
value_states = (
|
161 |
+
self.v_proj(hidden_states)
|
162 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
163 |
+
.transpose(1, 2)
|
164 |
+
)
|
165 |
+
|
166 |
+
kv_seq_len = key_states.shape[-2]
|
167 |
+
if past_key_value is not None:
|
168 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
169 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
170 |
+
(
|
171 |
+
query_states,
|
172 |
+
key_states,
|
173 |
+
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
174 |
+
query_states, key_states, cos, sin, position_ids
|
175 |
+
)
|
176 |
+
# [bsz, nh, t, hd]
|
177 |
+
|
178 |
+
if past_key_value is not None:
|
179 |
+
# reuse k, v, self_attention
|
180 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
181 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
182 |
+
|
183 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
184 |
+
|
185 |
+
# We only apply sdp attention if we don't need to output the whole attention matrix
|
186 |
+
if not output_attentions:
|
187 |
+
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
188 |
+
query_states,
|
189 |
+
key_states,
|
190 |
+
value_states,
|
191 |
+
attn_mask=attention_mask,
|
192 |
+
is_causal=False,
|
193 |
+
)
|
194 |
+
attn_weights = None
|
195 |
+
else:
|
196 |
+
attn_weights = torch.matmul(
|
197 |
+
query_states, key_states.transpose(2, 3)
|
198 |
+
) / math.sqrt(self.head_dim)
|
199 |
+
|
200 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
201 |
+
raise ValueError(
|
202 |
+
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
203 |
+
f" {attn_weights.size()}"
|
204 |
+
)
|
205 |
+
|
206 |
+
if attention_mask is not None:
|
207 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
208 |
+
raise ValueError(
|
209 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
210 |
+
)
|
211 |
+
attn_weights = attn_weights + attention_mask
|
212 |
+
attn_weights = torch.max(
|
213 |
+
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
214 |
+
)
|
215 |
+
|
216 |
+
# upcast attention to fp32
|
217 |
+
attn_weights = nn.functional.softmax(
|
218 |
+
attn_weights, dim=-1, dtype=torch.float32
|
219 |
+
).to(query_states.dtype)
|
220 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
221 |
+
|
222 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
223 |
+
raise ValueError(
|
224 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
225 |
+
f" {attn_output.size()}"
|
226 |
+
)
|
227 |
+
|
228 |
+
attn_output = attn_output.transpose(1, 2)
|
229 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
230 |
+
|
231 |
+
attn_output = self.o_proj(attn_output)
|
232 |
+
|
233 |
+
return attn_output, attn_weights, past_key_value
|
src/axolotl/utils/models.py
CHANGED
@@ -101,12 +101,19 @@ def load_model(
|
|
101 |
logging.info("patching with flash attention")
|
102 |
replace_llama_attn_with_flash_attn()
|
103 |
elif is_llama_derived_model and cfg.xformers_attention:
|
104 |
-
from
|
105 |
hijack_llama_attention,
|
106 |
)
|
107 |
|
108 |
logging.info("patching with xformers attention")
|
109 |
hijack_llama_attention()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
if cfg.bf16:
|
112 |
torch_dtype = torch.bfloat16
|
|
|
101 |
logging.info("patching with flash attention")
|
102 |
replace_llama_attn_with_flash_attn()
|
103 |
elif is_llama_derived_model and cfg.xformers_attention:
|
104 |
+
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
105 |
hijack_llama_attention,
|
106 |
)
|
107 |
|
108 |
logging.info("patching with xformers attention")
|
109 |
hijack_llama_attention()
|
110 |
+
elif is_llama_derived_model and cfg.sdp_attention:
|
111 |
+
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
112 |
+
hijack_llama_sdp_attention,
|
113 |
+
)
|
114 |
+
|
115 |
+
logging.info("patching with sdp attention")
|
116 |
+
hijack_llama_sdp_attention()
|
117 |
|
118 |
if cfg.bf16:
|
119 |
torch_dtype = torch.bfloat16
|