winglian committed
Commit 2d0ba3b · unverified · 2 Parent(s): c7021e1 c56818b

Merge pull request #124 from OpenAccess-AI-Collective/xformers-fix


Copy the xformers attention patch from ooba (text-generation-webui), since the dependency on alpaca_lora_4bit was removed.

.mypy.ini CHANGED
@@ -5,6 +5,9 @@ exclude = venv
 [mypy-alpaca_lora_4bit.*]
 ignore_missing_imports = True
 
+[mypy-axolotl.monkeypatch.*]
+ignore_errors = True
+
 [mypy-flash_attn.*]
 ignore_missing_imports = True
 
@@ -31,3 +34,6 @@ ignore_missing_imports = True
 
 [mypy-addict]
 ignore_missing_imports = True
+
+[mypy-xformers.*]
+ignore_missing_imports = True
README.md CHANGED
@@ -303,6 +303,9 @@ weight_decay:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
 flash_attention: # require a100 for llama
+# whether to use scaled-dot-product attention
+# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+sdp_attention:
 
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
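
For context, a minimal standalone sketch (not part of this commit) of the PyTorch primitive that the new sdp_attention option switches on; tensors use the [batch, heads, seq, head_dim] layout that the patched LLaMA attention passes in:

    # Hypothetical example; requires PyTorch >= 2.0, which provides F.scaled_dot_product_attention.
    import torch
    import torch.nn.functional as F

    bsz, num_heads, q_len, head_dim = 2, 4, 16, 64
    query = torch.randn(bsz, num_heads, q_len, head_dim)
    key = torch.randn(bsz, num_heads, q_len, head_dim)
    value = torch.randn(bsz, num_heads, q_len, head_dim)

    # Fused softmax(Q @ K^T / sqrt(head_dim)) @ V, computed without materializing
    # the full attention matrix when a fused kernel is available.
    out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=True)
    print(out.shape)  # torch.Size([2, 4, 16, 64])
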
src/axolotl/flash_attn.py CHANGED
@@ -25,6 +25,7 @@ def forward(
 
     attention_mask: [bsz, q_len]
     """
+    # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
     query_states = (
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py ADDED
@@ -0,0 +1,233 @@
+"""
+Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
+"""
+
+import logging
+import math
+from typing import Optional, Tuple
+
+import torch
+import transformers.models.llama.modeling_llama
+from torch import nn
+
+try:
+    import xformers.ops
+except ImportError:
+    logging.error("xformers not found! Please install it before trying to use it.")
+
+
+def hijack_llama_attention():
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+
+
+def hijack_llama_sdp_attention():
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+        sdp_attention_forward
+    )
+
+
+def xformers_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    (
+        query_states,
+        key_states,
+    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # We only apply xformers optimizations if we don't need to output the whole attention matrix
+    if not output_attentions:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(
+                query_states, key_states, value_states, attn_bias=None
+            )
+        else:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_bias=xformers.ops.LowerTriangularMask(),
+            )
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = self.o_proj(attn_output)
+    return attn_output, attn_weights, past_key_value
+
+
+def sdp_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # pylint: disable=duplicate-code
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    (
+        query_states,
+        key_states,
+    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    # We only apply sdp attention if we don't need to output the whole attention matrix
+    if not output_attentions:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            is_causal=False,
+        )
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
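
As an illustration, a standalone sketch (not part of the committed file; assumes xformers is installed and a CUDA device is available) of the memory_efficient_attention call the patched forward makes when output_attentions is False, using the [batch, seq, heads, head_dim] layout and a causal bias:

    # Hypothetical example mirroring the xformers branch above.
    import torch
    import xformers.ops

    bsz, q_len, num_heads, head_dim = 2, 16, 4, 64
    shape = (bsz, q_len, num_heads, head_dim)
    query = torch.randn(shape, device="cuda", dtype=torch.float16)
    key = torch.randn(shape, device="cuda", dtype=torch.float16)
    value = torch.randn(shape, device="cuda", dtype=torch.float16)

    # Causal (lower-triangular) attention without building a [q_len, q_len] mask tensor.
    attn_output = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=xformers.ops.LowerTriangularMask()
    )
    print(attn_output.shape)  # torch.Size([2, 16, 4, 64])
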
src/axolotl/utils/models.py CHANGED
@@ -101,12 +101,19 @@ def load_model(
         logging.info("patching with flash attention")
         replace_llama_attn_with_flash_attn()
     elif is_llama_derived_model and cfg.xformers_attention:
-        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
+        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        logging.info("patching with xformers attention")
        hijack_llama_attention()
+    elif is_llama_derived_model and cfg.sdp_attention:
+        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
+            hijack_llama_sdp_attention,
+        )
+
+        logging.info("patching with sdp attention")
+        hijack_llama_sdp_attention()
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
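
For reference, a hedged usage sketch of the new code path outside of axolotl's config flow ("huggyllama/llama-7b" is only a placeholder model id): the patch is applied up front, the way load_model does it when cfg.xformers_attention or cfg.sdp_attention is set, and it replaces LlamaAttention.forward on the class.

    # Hypothetical example; inside axolotl this is driven by the YAML config flags.
    from transformers import AutoModelForCausalLM

    from axolotl.monkeypatch.llama_attn_hijack_xformers import (
        hijack_llama_attention,  # or hijack_llama_sdp_attention
    )

    hijack_llama_attention()  # swaps in the xformers-based forward globally
    model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")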