winglian committed
Commit: ad0ea6a
Parent: 6cb2310

black formatting


ignore copied file
fix linting
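
The reformatting in this commit comes from running black over the repository. As a rough illustration only (not part of the commit), black's Python API can be applied to one of the long attention lines touched below; the exact output depends on the black version and line-length configuration in use.

# Illustration only: reformat one of the long lines from this commit with black's API.
# Assumes black is installed; the input string is taken from the diff below.
import black

LONG_LINE = (
    "query_states = self.q_proj(hidden_states)"
    ".view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n"
)

# black wraps the over-long statement to fit its default 88-character limit,
# which is the kind of change that makes up most of this diff.
print(black.format_str(LONG_LINE, mode=black.Mode()))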

.mypy.ini CHANGED
@@ -5,6 +5,9 @@ exclude = venv
 [mypy-alpaca_lora_4bit.*]
 ignore_missing_imports = True
 
+[mypy-axolotl.monkeypatch.*]
+ignore_errors = True
+
 [mypy-flash_attn.*]
 ignore_missing_imports = True
 
@@ -31,3 +34,6 @@ ignore_missing_imports = True
 
 [mypy-addict]
 ignore_missing_imports = True
+
+[mypy-xformers.*]
+ignore_missing_imports = True
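
The two new sections behave differently: ignore_errors = True silences every mypy error reported inside axolotl.monkeypatch.* (the package that holds the copied file below), while ignore_missing_imports = True only suppresses the "cannot find implementation or library stub" error for xformers. A hypothetical module, not from the repository, illustrating the difference:

# Hypothetical file for illustration only, e.g. src/axolotl/monkeypatch/example.py
import xformers.ops  # missing stubs: covered by [mypy-xformers.*] ignore_missing_imports


def wrong_return() -> int:
    # A genuine type error. mypy would normally report it, but because
    # [mypy-axolotl.monkeypatch.*] sets ignore_errors = True, all errors in this
    # package (including the copied llama_attn_hijack_xformers.py) are skipped.
    return "not an int"
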
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -1,18 +1,18 @@
-'''
+"""
 Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
-'''
+"""
 
 import logging
 import math
 from typing import Optional, Tuple
 
 import torch
-import torch.nn as nn
 import transformers.models.llama.modeling_llama
+from torch import nn
 
 try:
     import xformers.ops
-except Exception:
+except ImportError:
     logging.error("xformers not found! Please install it before trying to use it.")
 
 
@@ -22,7 +22,9 @@ def hijack_llama_attention():
 
 
 def hijack_llama_sdp_attention():
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+        sdp_attention_forward
+    )
     logging.info("Replaced attention with sdp_attention")
 
 
@@ -37,15 +39,32 @@ def xformers_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    (
+        query_states,
+        key_states,
+    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
     # [bsz, nh, t, hd]
 
     if past_key_value is not None:
@@ -65,13 +84,22 @@ def xformers_forward(
         # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
         if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
             # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+            attn_output = xformers.ops.memory_efficient_attention(
+                query_states, key_states, value_states, attn_bias=None
+            )
         else:
             # input and output should be of form (bsz, q_len, num_heads, head_dim)
-            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+            attn_output = xformers.ops.memory_efficient_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_bias=xformers.ops.LowerTriangularMask(),
+            )
         attn_weights = None
     else:
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -85,10 +113,14 @@ def xformers_forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
 
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -115,15 +147,32 @@ def sdp_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    (
+        query_states,
+        key_states,
+    ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
     # [bsz, nh, t, hd]
 
     if past_key_value is not None:
@@ -135,10 +184,18 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
    if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            is_causal=False,
+        )
         attn_weights = None
     else:
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(
+            query_states, key_states.transpose(2, 3)
+        ) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -152,10 +209,14 @@ def sdp_attention_forward(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(
+                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+            )
 
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
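
For context, a minimal usage sketch of how the patched module is typically wired in. This wiring is not part of the commit, and the checkpoint name is a placeholder; the hijack functions themselves come from the file above.

# Minimal sketch, not from this commit: apply the monkeypatch before building the model.
from transformers import LlamaForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention

# Replaces transformers.models.llama.modeling_llama.LlamaAttention.forward with the
# xformers memory-efficient implementation defined in the patched module, so every
# attention layer of a subsequently constructed LLaMA model uses it.
hijack_llama_attention()

# Placeholder model id; any LLaMA checkpoint would do.
model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b")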