bourdoiscatie commited on
Commit
a0e4338
1 Parent(s): 24225cf

Upload 10 files

Browse files
attn_ref.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def attn_ref(q, k, v, b, sm_scale, dropout_p=0.0, causal=False, upcast=False):
4
+ if upcast:
5
+ q, k, v = q.float(), k.float(), v.float()
6
+ if b is not None:
7
+ b = b.float()
8
+
9
+ if b is not None:
10
+ if (b.shape[0] != q.shape[0]) or (b.shape[1] != q.shape[1]):
11
+ b = b.expand(q.shape[0], q.shape[1], q.shape[2], k.shape[2])
12
+
13
+ ms = torch.arange(q.shape[2], device=q.device).unsqueeze(-1)
14
+ ns = torch.arange(k.shape[2], device=q.device)
15
+
16
+ p = torch.matmul(q, k.transpose(2, 3))
17
+ p *= sm_scale
18
+ if b is not None:
19
+ p += b
20
+
21
+ if causal:
22
+ p = torch.where(ms + k.shape[2] - q.shape[2] >= ns, p, float("-inf"))
23
+
24
+ p = torch.softmax(p.float(), dim=-1).to(q.dtype)
25
+ if dropout_p > 0.0:
26
+ p = torch.dropout(p, dropout_p, train=True)
27
+
28
+ ref_out = torch.matmul(p, v)
29
+ return ref_out
configuration_flash_t5.py CHANGED
@@ -6,7 +6,7 @@ import logging
6
  from transformers import T5Config
7
 
8
  AUTO_MAP = {
9
- "AutoModel": "modeling_flash_t5.FlashT5ForConditionalGeneration",
10
  "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
11
  "AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification",
12
  "AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
@@ -26,7 +26,7 @@ class FlashT5Config(T5Config):
26
  use_randomized_position_encoding=False,
27
  label_smoothing=0.0,
28
  z_loss=None,
29
- attention_type="ref",
30
  max_sequence_length=1024,
31
  attention_dropout_rate=0.0,
32
  alibi_mode="symetric",
@@ -39,9 +39,6 @@ class FlashT5Config(T5Config):
39
  rotary_base=10000,
40
  rotary_interleaved=False,
41
  rotary_scale_base=None,
42
- fire_mlp_width=32,
43
- use_masking=False,
44
- attention_scale=None,
45
  **kwargs,
46
  ):
47
  super().__init__(**kwargs)
@@ -53,7 +50,7 @@ class FlashT5Config(T5Config):
53
  self.use_randomized_position_encoding = use_randomized_position_encoding
54
  self.label_smoothing = label_smoothing
55
  self.z_loss = z_loss
56
- self.attention_type = attention_type
57
  self.max_sequence_length = max_sequence_length
58
  self.alibi_mode = alibi_mode
59
  self.attention_dropout_rate = attention_dropout_rate
@@ -66,9 +63,6 @@ class FlashT5Config(T5Config):
66
  self.rotary_interleaved = rotary_interleaved
67
  self.rotary_scale_base = rotary_scale_base
68
  self.rotary_emb_fraction = rotary_emb_fraction
69
- self.fire_mlp_width = fire_mlp_width
70
- self.use_masking = use_masking
71
- self.attention_scale = attention_scale
72
 
73
  self.auto_map = AUTO_MAP
74
 
 
6
  from transformers import T5Config
7
 
8
  AUTO_MAP = {
9
+ "AutoModel": "modeling_flash_t5.FlashT5EncoderModel",
10
  "AutoModelForSeq2SeqLM": "modeling_flash_t5.FlashT5ForConditionalGeneration",
11
  "AutoModelForTokenClassification": "custom_heads_flash_t5.FlashT5ForTokenClassification",
12
  "AutoModelForQuestionAnswering": "custom_heads_flash_t5.FlashT5ForQuestionAnswering",
 
26
  use_randomized_position_encoding=False,
27
  label_smoothing=0.0,
28
  z_loss=None,
29
+ use_flash_attention=None,
30
  max_sequence_length=1024,
31
  attention_dropout_rate=0.0,
32
  alibi_mode="symetric",
 
39
  rotary_base=10000,
40
  rotary_interleaved=False,
41
  rotary_scale_base=None,
 
 
 
42
  **kwargs,
43
  ):
44
  super().__init__(**kwargs)
 
50
  self.use_randomized_position_encoding = use_randomized_position_encoding
51
  self.label_smoothing = label_smoothing
52
  self.z_loss = z_loss
53
+ self.use_flash_attention = use_flash_attention
54
  self.max_sequence_length = max_sequence_length
55
  self.alibi_mode = alibi_mode
56
  self.attention_dropout_rate = attention_dropout_rate
 
63
  self.rotary_interleaved = rotary_interleaved
64
  self.rotary_scale_base = rotary_scale_base
65
  self.rotary_emb_fraction = rotary_emb_fraction
 
 
 
66
 
67
  self.auto_map = AUTO_MAP
68
 
cross_entropy_loss.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ # Copyright 2024 CATIE. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Modification to the original version from Unsloth:
17
+ # - return the z-loss
18
+ # - support for torch.compile
19
+
20
+ import triton
21
+ import triton.language as tl
22
+ import torch
23
+
24
+ MAX_FUSED_SIZE = 65536
25
+ next_power_of_2 = triton.next_power_of_2
26
+
27
+ def calculate_settings(n):
28
+ BLOCK_SIZE = next_power_of_2(n)
29
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
30
+ raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
31
+ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
32
+ num_warps = 4
33
+ if BLOCK_SIZE >= 32768: num_warps = 32
34
+ elif BLOCK_SIZE >= 8192: num_warps = 16
35
+ elif BLOCK_SIZE >= 2048: num_warps = 8
36
+ return BLOCK_SIZE, num_warps
37
+
38
+ @triton.jit
39
+ def _cross_entropy_forward(logits_ptr, logits_row_stride,
40
+ loss_ptr,
41
+ lse_ptr,
42
+ labels_ptr,
43
+ n_cols,
44
+ BLOCK_SIZE: tl.constexpr,
45
+ IS_EVEN: tl.constexpr):
46
+ """
47
+ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
48
+ Pi = exp(xi) / sum(exp(xi))
49
+ CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
50
+ = -y [ x - log[sum(exp(x))] ]
51
+ = y * (log[sum(exp(x))] - x)
52
+ If y == 0: CE_i = 0
53
+ If y == 1: CE_i = logsumexp - x
54
+ """
55
+ row_idx = tl.program_id(0)
56
+ logits_ptr += row_idx * logits_row_stride
57
+ loss_ptr += row_idx
58
+ lse_ptr += row_idx
59
+ labels_ptr += row_idx
60
+
61
+ col_offsets = tl.arange(0, BLOCK_SIZE)
62
+ mask = col_offsets < n_cols
63
+
64
+ # TODO: Fixup int32 locations to int64
65
+ label_idx = tl.load(labels_ptr).to(tl.int32)
66
+ if IS_EVEN:
67
+ logits = tl.load(logits_ptr + col_offsets).to(tl.float32)
68
+ else:
69
+ logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32)
70
+
71
+ max_logits = tl.max(logits, 0)
72
+
73
+ # Maximum stops overflow
74
+ lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
75
+ tl.store(lse_ptr, lse)
76
+
77
+ if label_idx != -100:
78
+ logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
79
+ loss = lse - logits_label
80
+ else:
81
+ loss = 0.0
82
+
83
+ tl.store(loss_ptr, loss)
84
+
85
+ @triton.jit
86
+ def _cross_entropy_backward(logits_ptr, logits_row_stride,
87
+ dinputs_ptr, dinputs_row_stride,
88
+ dloss_ptr, dloss_row_stride,
89
+ dzloss_ptr, dzloss_row_stride,
90
+ lse_ptr,
91
+ labels_ptr,
92
+ n_cols,
93
+ BLOCK_SIZE: tl.constexpr,
94
+ USE_Z_LOSS: tl.constexpr,
95
+ IS_EVEN: tl.constexpr):
96
+ """
97
+ CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
98
+ dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
99
+
100
+ From https://en.wikipedia.org/wiki/LogSumExp
101
+ d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
102
+
103
+ dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
104
+ dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
105
+ dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
106
+
107
+ If y == 0: dC/dx = 0
108
+ If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
109
+ If y == 1 and x != label: dC/dx = exp[x - logsumexp]
110
+ """
111
+
112
+ row_idx = tl.program_id(0)
113
+
114
+ logits_ptr += row_idx * logits_row_stride
115
+ dinputs_ptr += row_idx * dinputs_row_stride
116
+ dloss_ptr += row_idx * dloss_row_stride
117
+ dzloss_ptr += row_idx * dzloss_row_stride
118
+ col_offsets = tl.arange(0, BLOCK_SIZE)
119
+ mask = col_offsets < n_cols
120
+ # TODO: Fixup int32 locations to int64
121
+ label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
122
+
123
+ if label_idx != -100:
124
+ dloss = tl.load(dloss_ptr)
125
+ dzloss = tl.load(dzloss_ptr)
126
+ else:
127
+ dloss = 0.0
128
+ dzloss = 0.0
129
+
130
+ if IS_EVEN:
131
+ logits = tl.load(logits_ptr + col_offsets).to(tl.float32)
132
+ else:
133
+ logits = tl.load(logits_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
134
+
135
+ lse = tl.load(lse_ptr + row_idx)
136
+ probs = tl.exp(logits - lse)
137
+
138
+ probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
139
+ din = dloss * probs
140
+
141
+ # Z_loss
142
+ if USE_Z_LOSS:
143
+ if label_idx != -100:
144
+ dzloss = tl.load(dzloss_ptr)
145
+ else:
146
+ dzloss = 0.0
147
+
148
+ row_minus_max = logits
149
+ numerator = tl.exp(row_minus_max)
150
+ denominator = tl.sum(numerator, axis=0)
151
+ softmax_output = numerator / denominator
152
+ din += softmax_output * dzloss
153
+
154
+ if IS_EVEN:
155
+ tl.store(dinputs_ptr + col_offsets, din)
156
+ else:
157
+ tl.store(dinputs_ptr + col_offsets, din, mask=mask)
158
+
159
+
160
+ # Wrapper for triton kernel for torch.compile - should be unecessary for PyTorch 2.3 ?
161
+ torch.library.define("flasht5::cross_entropy_triton_fwd", "(Tensor logits, Tensor labels, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)")
162
+
163
+ @torch.library.impl("flasht5::cross_entropy_triton_fwd", "default")
164
+ def cross_entropy_triton_fwd(logits, labels, n_cols, n_rows, BLOCK_SIZE, num_warps):
165
+ losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
166
+ logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
167
+
168
+ _cross_entropy_forward[(n_rows,)](
169
+ logits, logits.stride(0),
170
+ losses,
171
+ logsumexp,
172
+ labels,
173
+ n_cols,
174
+ BLOCK_SIZE = BLOCK_SIZE,
175
+ IS_EVEN=((n_cols % BLOCK_SIZE) == 0),
176
+ num_warps = num_warps,
177
+ )
178
+
179
+ return losses, logsumexp
180
+
181
+
182
+ @torch.library.impl_abstract("flasht5::cross_entropy_triton_fwd", cross_entropy_triton_fwd)
183
+ def cross_entropy_triton_fwd_abstract(logits, labels, n_cols, n_rows, BLOCK_SIZE, num_warps):
184
+ losses = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
185
+ logsumexp = torch.empty(n_rows, dtype=torch.float32, device=logits.device)
186
+
187
+ return losses, logsumexp
188
+
189
+ torch.library.define("flasht5::cross_entropy_triton_bwd", "(Tensor dlosses, Tensor dlogsumexp, Tensor logits, Tensor logsumexp, Tensor labels, float z_loss_factor, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> Tensor")
190
+
191
+ @torch.library.impl("flasht5::cross_entropy_triton_bwd", "default")
192
+ def cross_entropy_triton_bwd(dlosses, dlogsumexp, logits, logsumexp, labels, z_loss_factor, n_cols, n_rows, BLOCK_SIZE, num_warps):
193
+
194
+ dinputs = torch.empty_like(logits)
195
+
196
+ _cross_entropy_backward[(n_rows,)](
197
+ logits, logits.stride(0),
198
+ dinputs, dinputs.stride(0),
199
+ dlosses, dlosses.stride(0),
200
+ dlogsumexp, dlogsumexp.stride(0),
201
+ logsumexp,
202
+ labels,
203
+ n_cols,
204
+ BLOCK_SIZE = BLOCK_SIZE,
205
+ USE_Z_LOSS = (z_loss_factor != 0.0),
206
+ IS_EVEN=((n_cols % BLOCK_SIZE) == 0),
207
+ num_warps = num_warps,
208
+ )
209
+
210
+ return dinputs
211
+
212
+
213
+ @torch.library.impl_abstract("flasht5::cross_entropy_triton_bwd", cross_entropy_triton_bwd)
214
+ def cross_entropy_triton_bwd_abstract(dlosses, dlogsumexp, logits, logsumexp, labels, z_loss_factor, n_cols, n_rows, BLOCK_SIZE, num_warps):
215
+ return torch.empty_like(logits)
216
+
217
+ class Fast_CrossEntropyLoss(torch.autograd.Function):
218
+ @staticmethod
219
+ def forward(ctx, logits, labels, z_loss_factor):
220
+ n_rows, n_cols = logits.shape
221
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
222
+
223
+ losses, logsumexp = torch.ops.flasht5.cross_entropy_triton_fwd(
224
+ logits,
225
+ labels,
226
+ n_cols,
227
+ n_rows,
228
+ BLOCK_SIZE = BLOCK_SIZE,
229
+ num_warps = num_warps
230
+ )
231
+
232
+ ctx.BLOCK_SIZE = BLOCK_SIZE
233
+ ctx.num_warps = num_warps
234
+ ctx.z_loss_factor = z_loss_factor
235
+ ctx.save_for_backward(logits, logsumexp, labels)
236
+ return losses, logsumexp
237
+
238
+ @staticmethod
239
+ def backward(ctx, dlosses, dlogsumexp):
240
+ logits, logsumexp, labels = ctx.saved_tensors
241
+ n_rows, n_cols = logits.shape
242
+
243
+ dinputs = torch.ops.flasht5.cross_entropy_triton_bwd(
244
+ dlosses,
245
+ dlogsumexp,
246
+ logits,
247
+ logsumexp,
248
+ labels,
249
+ ctx.z_loss_factor,
250
+ n_cols,
251
+ n_rows,
252
+ ctx.BLOCK_SIZE,
253
+ ctx.num_warps
254
+ )
255
+ return dinputs, None, None
256
+
257
+ def fast_cross_entropy_loss(logits, labels, z_loss_factor=0.0):
258
+ """
259
+ Arguments:
260
+ logits: (batch, seq_len, vocab_size)
261
+ labels: (batch, seq_len,)
262
+ Returns:
263
+ losses: float
264
+ """
265
+ batch, seq_len, d = logits.shape
266
+ assert(labels.shape == (batch, seq_len))
267
+ assert (d <= MAX_FUSED_SIZE)
268
+
269
+ loss, lse = Fast_CrossEntropyLoss.apply(
270
+ logits.view(batch*seq_len, d),
271
+ labels.view(-1),
272
+ z_loss_factor
273
+ )
274
+
275
+ n_items = torch.count_nonzero(labels != -100)
276
+
277
+ return loss.sum() / n_items, (z_loss_factor * torch.square(lse).sum()) / n_items
custom_heads_flash_t5.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
4
+ import copy
5
+ from typing import Optional, Union, Tuple, List
6
+ from transformers.modeling_outputs import (
7
+ Seq2SeqQuestionAnsweringModelOutput,
8
+ QuestionAnsweringModelOutput,
9
+ TokenClassifierOutput,
10
+ BaseModelOutput,
11
+ Seq2SeqSequenceClassifierOutput,
12
+ SequenceClassifierOutput
13
+ )
14
+
15
+ from .modeling_flash_t5 import FlashT5PreTrainedModel, FlashT5Stack, FlashT5Model, FlashT5EncoderModel
16
+ from .configuration_flash_t5 import FlashT5Config
17
+
18
+
19
+ ################## Encoder only head ##################
20
+ class FlashT5ForTokenClassification(FlashT5PreTrainedModel):
21
+
22
+ def __init__(self, config: FlashT5Config):
23
+ super().__init__(config)
24
+ self.num_labels = config.num_labels
25
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
26
+
27
+ self.encoder = FlashT5Stack(config, self.shared)
28
+ self.dropout = nn.Dropout(config.classifier_dropout)
29
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
30
+
31
+ # Initialize weights and apply final processing
32
+ self.post_init()
33
+
34
+ # Initialize classifier
35
+ self.classifier.weight.data.normal_(mean=0.0, std=config.initializer_factor * 1.0)
36
+ self.classifier.bias.data.zero_()
37
+
38
+ self.model_parallel = False
39
+
40
+ def forward(
41
+ self,
42
+ input_ids: Optional[torch.Tensor] = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ head_mask: Optional[torch.Tensor] = None,
45
+ inputs_embeds: Optional[torch.Tensor] = None,
46
+ labels: Optional[torch.Tensor] = None,
47
+ output_attentions: Optional[bool] = None,
48
+ output_hidden_states: Optional[bool] = None,
49
+ return_dict: Optional[bool] = None,
50
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
51
+ r"""
52
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
53
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
54
+ Returns:
55
+ """
56
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57
+
58
+ outputs = self.encoder(
59
+ input_ids=input_ids,
60
+ attention_mask=attention_mask,
61
+ inputs_embeds=inputs_embeds,
62
+ head_mask=head_mask,
63
+ output_attentions=output_attentions,
64
+ output_hidden_states=output_hidden_states,
65
+ return_dict=return_dict,
66
+ )
67
+
68
+ hidden_states = outputs[0]
69
+ hidden_states = self.dropout(hidden_states)
70
+ logits = self.classifier(hidden_states)
71
+
72
+ loss = None
73
+ if labels is not None:
74
+ loss_fct = nn.CrossEntropyLoss()
75
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
76
+
77
+ if not return_dict:
78
+ output = (logits, outputs[2:-1])
79
+ return ((loss,) + output) if loss is not None else output
80
+
81
+ return TokenClassifierOutput(
82
+ loss=loss,
83
+ logits=logits,
84
+ hidden_states=outputs.hidden_states,
85
+ attentions=outputs.attentions,
86
+ )
87
+
88
+
89
+ class FlashT5ClassificationHead(nn.Module):
90
+ """Head for sentence-level classification tasks."""
91
+
92
+ def __init__(self, config: FlashT5Config):
93
+ super().__init__()
94
+ self.dense = nn.Linear(config.d_model, config.d_model)
95
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
96
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
97
+
98
+ # initialize weights
99
+ factor = config.initializer_factor
100
+ self.dense.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
101
+ if hasattr(self.dense, "bias") and self.dense.bias is not None:
102
+ self.dense.bias.data.zero_()
103
+ self.out_proj.weight.data.normal_(mean=0.0, std=factor * ((config.d_model) ** -0.5))
104
+ if hasattr(self.out_proj, "bias") and self.out_proj.bias is not None:
105
+ self.out_proj.bias.data.zero_()
106
+
107
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
108
+ hidden_states = self.dropout(hidden_states)
109
+ hidden_states = self.dense(hidden_states)
110
+ hidden_states = torch.tanh(hidden_states)
111
+ hidden_states = self.dropout(hidden_states)
112
+ hidden_states = self.out_proj(hidden_states)
113
+ return hidden_states
114
+
115
+
116
+ class FlashT5ForSequenceClassification(FlashT5PreTrainedModel):
117
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
118
+
119
+ def __init__(self, config: FlashT5Config):
120
+ super().__init__(config)
121
+ self.model_dim = config.d_model
122
+ self.config.problem_type = None
123
+ self.config.is_encoder_decoder = False
124
+
125
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
126
+
127
+ encoder_config = copy.deepcopy(config)
128
+ encoder_config.is_decoder = False
129
+ encoder_config.is_encoder_decoder = False
130
+ encoder_config.use_cache = False
131
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
132
+ self.classification_head = FlashT5ClassificationHead(config)
133
+
134
+ # Initialize weights and apply final processing
135
+ self.post_init()
136
+
137
+ self.model_parallel = False
138
+
139
+ def forward(
140
+ self,
141
+ input_ids: torch.LongTensor = None,
142
+ attention_mask: Optional[torch.Tensor] = None,
143
+ head_mask: Optional[torch.Tensor] = None,
144
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
145
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
146
+ inputs_embeds: Optional[torch.FloatTensor] = None,
147
+ labels: Optional[torch.LongTensor] = None,
148
+ use_cache: Optional[bool] = None,
149
+ output_attentions: Optional[bool] = None,
150
+ output_hidden_states: Optional[bool] = None,
151
+ return_dict: Optional[bool] = None,
152
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
153
+ r"""
154
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
155
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
156
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
157
+ Returns:
158
+ """
159
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
160
+ if labels is not None:
161
+ use_cache = False
162
+
163
+ if input_ids is None and inputs_embeds is not None:
164
+ raise NotImplementedError(
165
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
166
+ )
167
+
168
+
169
+ outputs = self.encoder(
170
+ input_ids=input_ids,
171
+ attention_mask=attention_mask,
172
+ inputs_embeds=inputs_embeds,
173
+ head_mask=head_mask,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+ sequence_output = outputs[0]
179
+
180
+ eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
181
+
182
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
183
+ raise ValueError("All examples must have the same number of <eos> tokens.")
184
+ batch_size, _, hidden_size = sequence_output.shape
185
+ sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
186
+ logits = self.classification_head(sentence_representation)
187
+
188
+ loss = None
189
+ if labels is not None:
190
+ labels = labels.to(logits.device)
191
+ if self.config.problem_type is None:
192
+ if self.config.num_labels == 1:
193
+ self.config.problem_type = "regression"
194
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
195
+ self.config.problem_type = "single_label_classification"
196
+ else:
197
+ self.config.problem_type = "multi_label_classification"
198
+
199
+ if self.config.problem_type == "regression":
200
+ loss_fct = nn.MSELoss()
201
+ if self.config.num_labels == 1:
202
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
203
+ else:
204
+ loss = loss_fct(logits, labels)
205
+ elif self.config.problem_type == "single_label_classification":
206
+ loss_fct = nn.CrossEntropyLoss()
207
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
208
+ elif self.config.problem_type == "multi_label_classification":
209
+ loss_fct = nn.BCEWithLogitsLoss()
210
+ loss = loss_fct(logits, labels)
211
+ if not return_dict:
212
+ output = (logits,) + outputs[1:]
213
+ return ((loss,) + output) if loss is not None else output
214
+
215
+ return SequenceClassifierOutput(
216
+ loss=loss,
217
+ logits=logits,
218
+ hidden_states=outputs.hidden_states,
219
+ attentions=outputs.attentions
220
+ )
221
+
222
+
223
+
224
+ ################## Seq2Seq head ##################
225
+ class FlashT5ForQuestionAnswering(FlashT5PreTrainedModel):
226
+ _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
227
+
228
+ def __init__(self, config: FlashT5Config):
229
+ super().__init__(config)
230
+ self.transformer = FlashT5EncoderModel(config)
231
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
232
+
233
+ # Initialize weights and apply final processing
234
+ self.post_init()
235
+
236
+ self.model_parallel = False
237
+
238
+ def forward(
239
+ self,
240
+ input_ids: Optional[torch.LongTensor] = None,
241
+ attention_mask: Optional[torch.FloatTensor] = None,
242
+ head_mask: Optional[torch.FloatTensor] = None,
243
+ inputs_embeds: Optional[torch.FloatTensor] = None,
244
+ start_positions: Optional[torch.LongTensor] = None,
245
+ end_positions: Optional[torch.LongTensor] = None,
246
+ output_attentions: Optional[bool] = None,
247
+ output_hidden_states: Optional[bool] = None,
248
+ return_dict: Optional[bool] = None,
249
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
250
+ r"""
251
+ Returns:
252
+
253
+ Example:
254
+
255
+ ```python
256
+ >>> from transformers import AutoTokenizer, MTxEncoderForQuestionAnswering
257
+
258
+ >>> tokenizer = AutoTokenizer.from_pretrained("MTx-small")
259
+ >>> model = MTxEncoderForQuestionAnswering.from_pretrained("MTx-small")
260
+ >>> input_ids = tokenizer(
261
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
262
+ ... ).input_ids # Batch size 1
263
+ >>> outputs = model(input_ids=input_ids)
264
+ >>> start_logits = outputs.start_logits
265
+ >>> end_logits = outputs.end_logits
266
+ ```"""
267
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
268
+
269
+ outputs = self.transformer(
270
+ input_ids,
271
+ attention_mask=attention_mask,
272
+ head_mask=head_mask,
273
+ inputs_embeds=inputs_embeds,
274
+ output_attentions=output_attentions,
275
+ output_hidden_states=output_hidden_states,
276
+ return_dict=return_dict,
277
+ )
278
+ sequence_output = outputs[0]
279
+
280
+ logits = self.qa_outputs(sequence_output)
281
+ start_logits, end_logits = logits.split(1, dim=-1)
282
+ start_logits = start_logits.squeeze(-1).contiguous()
283
+ end_logits = end_logits.squeeze(-1).contiguous()
284
+
285
+ total_loss = None
286
+ if start_positions is not None and end_positions is not None:
287
+ # If we are on multi-GPU, split add a dimension
288
+ if len(start_positions.size()) > 1:
289
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
290
+ if len(end_positions.size()) > 1:
291
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
292
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
293
+ ignored_index = start_logits.size(1)
294
+ start_positions = start_positions.clamp(0, ignored_index)
295
+ end_positions = end_positions.clamp(0, ignored_index)
296
+
297
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
298
+ start_loss = loss_fct(start_logits, start_positions)
299
+ end_loss = loss_fct(end_logits, end_positions)
300
+ total_loss = (start_loss + end_loss) / 2
301
+
302
+ if not return_dict:
303
+ output = (start_logits, end_logits) + outputs[1:]
304
+ return ((total_loss,) + output) if total_loss is not None else output
305
+
306
+ return QuestionAnsweringModelOutput(
307
+ loss=total_loss,
308
+ start_logits=start_logits,
309
+ end_logits=end_logits,
310
+ hidden_states=outputs.hidden_states,
311
+ attentions=outputs.attentions,
312
+ )
fa2_compilable.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Tri Dao.
2
+
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ # isort: off
9
+ # We need to import the CUDA kernels after importing torch
10
+ import flash_attn_2_cuda as flash_attn_cuda
11
+
12
+ # isort: on
13
+
14
+ torch.library.define("fa2::fwd", "(Tensor q, Tensor k, Tensor v, Tensor out, Tensor alibi_slopes, float dropout_p, float softmax_scale, bool causal, int window_size_left, int window_size_right, Tensor attn_bias, bool return_softmax, Tensor gen_) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)")
15
+
16
+ @torch.library.impl("fa2::fwd", "default")
17
+ def cuda_fa2_fwd(
18
+ q: torch.Tensor,
19
+ k: torch.Tensor,
20
+ v: torch.Tensor,
21
+ out: torch.Tensor,
22
+ alibi_slopes: torch.Tensor,
23
+ dropout_p: float,
24
+ softmax_scale: float,
25
+ causal: bool,
26
+ window_size_left: int,
27
+ window_size_right: int,
28
+ attn_bias: torch.Tensor,
29
+ return_softmax: bool,
30
+ gen_: torch.Tensor,
31
+ ):
32
+
33
+ out, q, k, v, out_padded, attn_bias, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(q, k, v, out, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left, window_size_right, attn_bias, return_softmax, None)
34
+ return out, q, k, v, out_padded, attn_bias, softmax_lse, S_dmask, rng_state
35
+
36
+ @torch.library.impl_abstract("fa2::fwd", cuda_fa2_fwd)
37
+ def meta_fa2_fwd(
38
+ q: torch.Tensor,
39
+ k: torch.Tensor,
40
+ v: torch.Tensor,
41
+ out: torch.Tensor,
42
+ alibi_slopes: torch.Tensor,
43
+ dropout_p: float,
44
+ softmax_scale: float,
45
+ causal: bool,
46
+ window_size_left: int,
47
+ window_size_right: int,
48
+ attn_bias: torch.Tensor,
49
+ return_softmax: bool,
50
+ gen_: torch.Tensor
51
+ ):
52
+
53
+ round_multiple = lambda x, m: (x + m - 1) // m * m
54
+ batch_size = q.shape[0]
55
+ seqlen_q = q.shape[1]
56
+ seqlen_k = k.shape[1]
57
+ num_heads = q.shape[2]
58
+ head_dim_og = q.shape[3]
59
+ seqlen_q_rounded = round_multiple(seqlen_q, 128)
60
+ seqlen_k_rounded = round_multiple(seqlen_k, 128)
61
+ seqlen_q_rounded_8 = round_multiple(seqlen_q, 8)
62
+ seqlen_k_rounded_8 = round_multiple(seqlen_k, 8)
63
+ head_dim = round_multiple(head_dim_og, 8)
64
+
65
+ if attn_bias is not None:
66
+ batch_size_bias = attn_bias.shape[0]
67
+ num_heads_bias = attn_bias.shape[1]
68
+
69
+ return (torch.empty_strided((batch_size, seqlen_q, num_heads, head_dim_og),
70
+ (head_dim*num_heads*seqlen_q, head_dim*num_heads, head_dim, 1), device=q.device, dtype=q.dtype), # out
71
+ q.new_empty((batch_size, seqlen_q, num_heads, head_dim)), # q_padded
72
+ k.new_empty((batch_size, seqlen_k, num_heads, head_dim)), # k_padded
73
+ v.new_empty((batch_size, seqlen_k, num_heads, head_dim)), # v_padded
74
+ q.new_empty((batch_size, seqlen_q, num_heads, head_dim)), # out_padded
75
+ q.new_empty((batch_size_bias, num_heads_bias, seqlen_q_rounded_8, seqlen_k_rounded_8)) if attn_bias is not None else None, # attn_bias
76
+ q.new_empty((batch_size, num_heads, seqlen_q)), # softmax_lse
77
+ q.new_empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded)) if return_softmax and (dropout_p > 0) else None, # p
78
+ torch.empty((2), dtype=torch.int64, device=q.device) # rng_state
79
+ )
80
+
81
+ torch.library.define("fa2::bwd", "(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor dq, Tensor dk, Tensor dv, Tensor alibi_slopes, float dropout_p, float softmax_scale, bool causal, int window_size_left, int window_size_right, bool deterministic, Tensor attn_bias, bool attn_bias_require_grad, Tensor ds, int seqlen_k_orig, Tensor gen_, Tensor rng_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")
82
+
83
+ @torch.library.impl("fa2::bwd", "default")
84
+ def cuda_fa2_bwd(
85
+ dout: torch.Tensor,
86
+ q: torch.Tensor,
87
+ k: torch.Tensor,
88
+ v: torch.Tensor,
89
+ out: torch.Tensor,
90
+ softmax_lse: torch.Tensor,
91
+ dq: torch.Tensor,
92
+ dk: torch.Tensor,
93
+ dv: torch.Tensor,
94
+ alibi_slopes: torch.Tensor,
95
+ dropout_p: float,
96
+ softmax_scale: float,
97
+ causal: bool,
98
+ window_size_left: int,
99
+ window_size_right: int,
100
+ deterministic: bool,
101
+ attn_bias: torch.Tensor,
102
+ attn_bias_require_grad: bool,
103
+ ds: torch.Tensor,
104
+ seqlen_k_orig: int,
105
+ gen_: torch.Tensor,
106
+ rng_sate: torch.Tensor
107
+ ):
108
+ dq, dk, dv, ds, s = flash_attn_cuda.bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, dropout_p, softmax_scale, causal, window_size_left, window_size_right, deterministic, attn_bias, attn_bias_require_grad, ds, None, rng_sate)
109
+ return dq, dk, dv, ds, s
110
+
111
+ @torch.library.impl_abstract("fa2::bwd", cuda_fa2_bwd)
112
+ def meta_fa2_bwd(
113
+ dout: torch.Tensor,
114
+ q: torch.Tensor,
115
+ k: torch.Tensor,
116
+ v: torch.Tensor,
117
+ out: torch.Tensor,
118
+ softmax_lse: torch.Tensor,
119
+ dq: torch.Tensor,
120
+ dk: torch.Tensor,
121
+ dv: torch.Tensor,
122
+ alibi_slopes: torch.Tensor,
123
+ dropout_p: float,
124
+ softmax_scale: float,
125
+ causal: bool,
126
+ window_size_left: int,
127
+ window_size_right: int,
128
+ deterministic: bool,
129
+ attn_bias: torch.Tensor,
130
+ attn_bias_require_grad: bool,
131
+ ds: torch.Tensor,
132
+ seqlen_k_orig: int,
133
+ gen_: torch.Tensor,
134
+ rng_sate: torch.Tensor
135
+ ):
136
+
137
+ round_multiple = lambda x, m: (x + m - 1) // m * m
138
+ batch_size = dout.shape[0]
139
+ seqlen_q = dout.shape[1]
140
+ seqlen_k = k.shape[1]
141
+ seqlen_q_rounded = round_multiple(seqlen_q, 128)
142
+ num_heads = dout.shape[2]
143
+ head_dim_og = dout.shape[3]
144
+ head_dim = round_multiple(head_dim_og, 8)
145
+ seqlen_q_round8 = round_multiple(seqlen_q, 8)
146
+ seqlen_k_round8 = round_multiple(seqlen_k_orig, 8)
147
+
148
+ if attn_bias is not None:
149
+ batch_size_bias = attn_bias.shape[0]
150
+ num_heads_bias = attn_bias.shape[1]
151
+
152
+ return (torch.empty_strided((batch_size, seqlen_q, num_heads, head_dim_og),
153
+ (head_dim*num_heads*seqlen_q, head_dim*num_heads, head_dim, 1), device=q.device, dtype=q.dtype),
154
+ torch.empty_strided((batch_size, seqlen_k_orig, num_heads, head_dim_og),
155
+ (head_dim*num_heads*seqlen_k, head_dim*num_heads, head_dim, 1), device=k.device, dtype=k.dtype),
156
+ torch.empty_strided((batch_size, seqlen_k, num_heads, head_dim_og),
157
+ (head_dim*num_heads*seqlen_k, head_dim*num_heads, head_dim, 1), device=v.device, dtype=v.dtype),
158
+ torch.empty_strided((batch_size_bias, num_heads_bias, seqlen_q, seqlen_k_orig),
159
+ (num_heads_bias*seqlen_q_round8*seqlen_k_round8, seqlen_q_round8*seqlen_k_round8, seqlen_q_round8, 1), device=v.device, dtype=v.dtype)
160
+ if attn_bias_require_grad else None,
161
+ q.new_empty((batch_size, num_heads, seqlen_q_rounded))
162
+ )
163
+
164
+
165
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
166
+ @staticmethod
167
+ def forward(
168
+ ctx,
169
+ qkv,
170
+ dropout_p,
171
+ softmax_scale,
172
+ causal,
173
+ window_size_left,
174
+ window_size_right,
175
+ alibi_slopes,
176
+ deterministic,
177
+ attn_bias,
178
+ return_softmax,
179
+ return_ds
180
+ ):
181
+ if softmax_scale is None:
182
+ softmax_scale = qkv.shape[-1] ** (-0.5)
183
+
184
+ out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
185
+ qkv[:, :, 0],
186
+ qkv[:, :, 1],
187
+ qkv[:, :, 2],
188
+ None,
189
+ alibi_slopes,
190
+ dropout_p,
191
+ softmax_scale,
192
+ causal,
193
+ window_size_left,
194
+ window_size_right,
195
+ attn_bias,
196
+ return_softmax and dropout_p > 0,
197
+ None
198
+ )
199
+
200
+ ## WORKAROUND a Pytorch bug, should use _padded version of the tensors but this is buggy when passing them directly to save_for_backward
201
+ ## For now, this breaks the backward when headdim is not a multiple of 8 and/or seqlen_q, seqlen_k are not a multiple of 8
202
+ ## TODO: make the padding here instead
203
+ ctx.save_for_backward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], out, softmax_lse, rng_state, attn_bias, alibi_slopes)
204
+ #ctx.save_for_backward(q_padded, k_padded, v_padded, out_padded, softmax_lse, rng_state, attn_bias_padded, alibi_slopes)
205
+ ctx.dropout_p = dropout_p
206
+ ctx.softmax_scale = softmax_scale
207
+ ctx.causal = causal
208
+ ctx.window_size_left = window_size_left
209
+ ctx.window_size_right = window_size_right
210
+ ctx.deterministic = deterministic
211
+ ctx.bias_requires_grad = True if attn_bias is not None and return_ds else False
212
+ ctx.seqlen_k_orig = qkv.shape[1]
213
+
214
+ return out if not return_softmax else (out, softmax_lse, S_dmask)
215
+
216
+ @staticmethod
217
+ def backward(ctx, dout, *args):
218
+ q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors
219
+
220
+ dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
221
+ dout,
222
+ q,
223
+ k,
224
+ v,
225
+ out,
226
+ softmax_lse,
227
+ None,
228
+ None,
229
+ None,
230
+ alibi_slopes,
231
+ ctx.dropout_p,
232
+ ctx.softmax_scale,
233
+ ctx.causal,
234
+ ctx.window_size_left,
235
+ ctx.window_size_right,
236
+ ctx.deterministic,
237
+ attn_bias,
238
+ ctx.bias_requires_grad,
239
+ None,
240
+ ctx.seqlen_k_orig,
241
+ None,
242
+ rng_state
243
+ )
244
+ dqkv = torch.stack([dq, dk, dv], dim=2)
245
+ return dqkv, None, None, None, None, None, None, None, ds, None, None
246
+
247
+ class FlashAttnKVPackedFunc(torch.autograd.Function):
248
+ @staticmethod
249
+ def forward(
250
+ ctx,
251
+ q,
252
+ kv,
253
+ dropout_p,
254
+ softmax_scale,
255
+ causal,
256
+ window_size_left,
257
+ window_size_right,
258
+ alibi_slopes,
259
+ deterministic,
260
+ attn_bias,
261
+ return_softmax,
262
+ return_ds
263
+ ):
264
+ if softmax_scale is None:
265
+ softmax_scale = q.shape[-1] ** (-0.5)
266
+
267
+ out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
268
+ q,
269
+ kv[:, :, 0],
270
+ kv[:, :, 1],
271
+ None,
272
+ alibi_slopes,
273
+ dropout_p,
274
+ softmax_scale,
275
+ causal,
276
+ window_size_left,
277
+ window_size_right,
278
+ attn_bias,
279
+ return_softmax and dropout_p > 0,
280
+ None
281
+ )
282
+
283
+ ## WORKAROUND a Pytorch bug, should use _padded version of the tensors but this is buggy when passing them directly to save_for_backward
284
+ ## For now, this breaks the backward when headdim is not a multiple of 8 and/or seqlen_q, seqlen_k are not a multiple of 8
285
+ ## TODO: make the padding here instead
286
+ ctx.save_for_backward(q, kv[:, :, 0], kv[:, :, 1], out, softmax_lse, rng_state, attn_bias, alibi_slopes)
287
+ #ctx.save_for_backward(q_padded, k_padded, v_padded, out_padded, softmax_lse, rng_state, attn_bias_padded, alibi_slopes)
288
+ ctx.dropout_p = dropout_p
289
+ ctx.softmax_scale = softmax_scale
290
+ ctx.causal = causal
291
+ ctx.window_size_left = window_size_left
292
+ ctx.window_size_right = window_size_right
293
+ ctx.deterministic = deterministic
294
+ ctx.bias_requires_grad = True if attn_bias is not None and return_ds else False
295
+ ctx.seqlen_k_orig = kv.shape[1]
296
+ return out if not return_softmax else (out, softmax_lse, S_dmask)
297
+
298
+ @staticmethod
299
+ def backward(ctx, dout, *args):
300
+ q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors
301
+
302
+ dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
303
+ dout,
304
+ q,
305
+ k,
306
+ v,
307
+ out,
308
+ softmax_lse,
309
+ None,
310
+ None,
311
+ None,
312
+ alibi_slopes,
313
+ ctx.dropout_p,
314
+ ctx.softmax_scale,
315
+ ctx.causal,
316
+ ctx.window_size_left,
317
+ ctx.window_size_right,
318
+ ctx.deterministic,
319
+ attn_bias,
320
+ ctx.bias_requires_grad,
321
+ None,
322
+ ctx.seqlen_k_orig,
323
+ None,
324
+ rng_state
325
+ )
326
+ dkv = torch.stack([dk, dv], dim=2)
327
+
328
+ return dq, dkv, None, None, None, None, None, None, None, ds, None, None
329
+
330
+ class FlashAttnFunc(torch.autograd.Function):
331
+ @staticmethod
332
+ def forward(
333
+ ctx,
334
+ q,
335
+ k,
336
+ v,
337
+ dropout_p,
338
+ softmax_scale,
339
+ causal,
340
+ window_size_left,
341
+ window_size_right,
342
+ alibi_slopes,
343
+ deterministic,
344
+ attn_bias,
345
+ return_softmax,
346
+ return_ds
347
+ ):
348
+
349
+ batch_size, seqlen_q = q.shape[:2]
350
+ seqlen_k = k.shape[1]
351
+
352
+ if softmax_scale is None:
353
+ softmax_scale = q.shape[-1] ** (-0.5)
354
+
355
+ if attn_bias is not None:
356
+ attn_bias = attn_bias.to(q.dtype)
357
+
358
+ out, q_padded, k_padded, v_padded, out_padded, attn_bias_padded, softmax_lse, S_dmask, rng_state = torch.ops.fa2.fwd(
359
+ q,
360
+ k,
361
+ v,
362
+ None,
363
+ alibi_slopes,
364
+ dropout_p,
365
+ softmax_scale,
366
+ causal,
367
+ window_size_left,
368
+ window_size_right,
369
+ attn_bias,
370
+ return_softmax and dropout_p > 0,
371
+ None
372
+ )
373
+
374
+ ## WORKAROUND a Pytorch bug, should use _padded version of the tensors but this is buggy when passing them directly to save_for_backward
375
+ ## For now, this breaks the backward when headdim is not a multiple of 8 and/or seqlen_q, seqlen_k are not a multiple of 8
376
+ ## TODO: make the padding here instead
377
+ ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes)
378
+ #ctx.save_for_backward(q_padded, k_padded, v_padded, out_padded, softmax_lse, rng_state, attn_bias_padded, alibi_slopes)
379
+
380
+ ctx.dropout_p = dropout_p
381
+ ctx.softmax_scale = softmax_scale
382
+ ctx.causal = causal
383
+ ctx.window_size_left = window_size_left
384
+ ctx.window_size_right = window_size_right
385
+ ctx.deterministic = deterministic
386
+ ctx.bias_requires_grad = True if attn_bias is not None and return_ds else False
387
+ ctx.seqlen_k_orig = k.shape[1]
388
+
389
+ return out if not return_softmax else (out, softmax_lse, S_dmask)
390
+
391
+ @staticmethod
392
+ def backward(ctx, dout, *args):
393
+ q, k, v, out, softmax_lse, rng_state, attn_bias, alibi_slopes = ctx.saved_tensors
394
+
395
+ dout = dout.contiguous()
396
+ dq, dk, dv, ds, _ = torch.ops.fa2.bwd(
397
+ dout,
398
+ q,
399
+ k,
400
+ v,
401
+ out,
402
+ softmax_lse,
403
+ None,
404
+ None,
405
+ None,
406
+ alibi_slopes,
407
+ ctx.dropout_p,
408
+ ctx.softmax_scale,
409
+ ctx.causal,
410
+ ctx.window_size_left,
411
+ ctx.window_size_right,
412
+ ctx.deterministic,
413
+ attn_bias,
414
+ ctx.bias_requires_grad,
415
+ None,
416
+ ctx.seqlen_k_orig,
417
+ None,
418
+ rng_state
419
+ )
420
+
421
+ return dq, dk, dv, None, None, None, None, None, None, None, ds, None, None
422
+
423
+
424
+ def flash_attn_qkvpacked_func(
425
+ qkv,
426
+ dropout_p=0.0,
427
+ softmax_scale=None,
428
+ causal=False,
429
+ window_size_left=-1,
430
+ window_size_right=-1, # -1 means infinite context window
431
+ alibi_slopes=None,
432
+ deterministic=False,
433
+ attn_bias=None,
434
+ return_attn_probs=False,
435
+ return_ds=False
436
+ ):
437
+ """dropout_p should be set to 0.0 during evaluation
438
+ If Q, K, V are already stacked into 1 tensor, this function will be faster than
439
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
440
+ of the gradients of Q, K, V.
441
+ For multi-query and grouped-query attention (MQA/GQA), please see
442
+ flash_attn_kvpacked_func and flash_attn_func.
443
+
444
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
445
+ will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
446
+
447
+ Arguments:
448
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
449
+ dropout_p: float. Dropout probability.
450
+ softmax_scale: float. The scaling of QK^T before applying softmax.
451
+ Default to 1 / sqrt(headdim).
452
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
453
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
454
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
455
+ the attention score of query i and key j.
456
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
457
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
458
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
459
+ testing only. The returned probabilities are not guaranteed to be correct
460
+ (they might not have the right scaling).
461
+ Return:
462
+ out: (batch_size, seqlen, nheads, headdim).
463
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
464
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
465
+ normalization factor).
466
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
467
+ The output of softmax (possibly with different scaling). It also encodes the dropout
468
+ pattern (negative means that location was dropped, nonnegative means it was kept).
469
+ """
470
+ return FlashAttnQKVPackedFunc.apply(
471
+ qkv,
472
+ dropout_p,
473
+ softmax_scale,
474
+ causal,
475
+ window_size_left,
476
+ window_size_right,
477
+ alibi_slopes,
478
+ deterministic,
479
+ attn_bias,
480
+ return_attn_probs,
481
+ return_ds
482
+ )
483
+
484
+
485
+ def flash_attn_kvpacked_func(
486
+ q,
487
+ kv,
488
+ dropout_p=0.0,
489
+ softmax_scale=None,
490
+ causal=False,
491
+ window_size_left=-1,
492
+ window_size_right=-1, # -1 means infinite context window
493
+ alibi_slopes=None,
494
+ deterministic=False,
495
+ attn_bias=None,
496
+ return_attn_probs=False,
497
+ return_ds=False
498
+ ):
499
+ """dropout_p should be set to 0.0 during evaluation
500
+ If K, V are already stacked into 1 tensor, this function will be faster than
501
+ calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
502
+ of the gradients of K, V.
503
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
504
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
505
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
506
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
507
+
508
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
509
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
510
+ 1 1 1 1 0
511
+ 1 1 1 1 1
512
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
513
+ 0 0
514
+ 0 0
515
+ 0 0
516
+ 1 0
517
+ 1 1
518
+ If the row of the mask is all zero, the output will be zero.
519
+
520
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
521
+ will only attend to keys between
522
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
523
+
524
+ Arguments:
525
+ q: (batch_size, seqlen, nheads, headdim)
526
+ kv: (batch_size, seqlen, 2, nheads_k, headdim)
527
+ dropout_p: float. Dropout probability.
528
+ softmax_scale: float. The scaling of QK^T before applying softmax.
529
+ Default to 1 / sqrt(headdim).
530
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
531
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
532
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
533
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
534
+ is added to the attention score of query i and key j.
535
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
536
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
537
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
538
+ testing only. The returned probabilities are not guaranteed to be correct
539
+ (they might not have the right scaling).
540
+ Return:
541
+ out: (batch_size, seqlen, nheads, headdim).
542
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
543
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
544
+ normalization factor).
545
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
546
+ The output of softmax (possibly with different scaling). It also encodes the dropout
547
+ pattern (negative means that location was dropped, nonnegative means it was kept).
548
+ """
549
+ return FlashAttnKVPackedFunc.apply(
550
+ q,
551
+ kv,
552
+ dropout_p,
553
+ softmax_scale,
554
+ causal,
555
+ window_size_left,
556
+ window_size_right,
557
+ alibi_slopes,
558
+ deterministic,
559
+ attn_bias,
560
+ return_attn_probs,
561
+ return_ds
562
+ )
563
+
564
+
565
+ def flash_attn_func(
566
+ q,
567
+ k,
568
+ v,
569
+ dropout_p=0.0,
570
+ softmax_scale=None,
571
+ causal=False,
572
+ window_size_left=-1,
573
+ window_size_right=-1, # -1 means infinite context window
574
+ alibi_slopes=None,
575
+ deterministic=False,
576
+ attn_bias=None,
577
+ return_attn_probs=False,
578
+ return_ds=False
579
+ ):
580
+ """dropout_p should be set to 0.0 during evaluation
581
+ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
582
+ than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
583
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
584
+ 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
585
+
586
+ If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
587
+ For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
588
+ 1 1 1 1 0
589
+ 1 1 1 1 1
590
+ If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
591
+ 0 0
592
+ 0 0
593
+ 0 0
594
+ 1 0
595
+ 1 1
596
+ If the row of the mask is all zero, the output will be zero.
597
+
598
+ If window_size != (-1, -1), implements sliding window local attention. Query at position i
599
+ will only attend to keys between
600
+ [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
601
+
602
+ Arguments:
603
+ q: (batch_size, seqlen, nheads, headdim)
604
+ k: (batch_size, seqlen, nheads_k, headdim)
605
+ v: (batch_size, seqlen, nheads_k, headdim)
606
+ dropout_p: float. Dropout probability.
607
+ softmax_scale: float. The scaling of QK^T before applying softmax.
608
+ Default to 1 / sqrt(headdim).
609
+ causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
610
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention.
611
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
612
+ (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
613
+ is added to the attention score of query i and key j.
614
+ deterministic: bool. Whether to use the deterministic implementation of the backward pass,
615
+ which is slightly slower and uses more memory. The forward pass is always deterministic.
616
+ return_attn_probs: bool. Whether to return the attention probabilities. This option is for
617
+ testing only. The returned probabilities are not guaranteed to be correct
618
+ (they might not have the right scaling).
619
+ Return:
620
+ out: (batch_size, seqlen, nheads, headdim).
621
+ softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
622
+ logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
623
+ normalization factor).
624
+ S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
625
+ The output of softmax (possibly with different scaling). It also encodes the dropout
626
+ pattern (negative means that location was dropped, nonnegative means it was kept).
627
+ """
628
+ return FlashAttnFunc.apply(
629
+ q,
630
+ k,
631
+ v,
632
+ dropout_p,
633
+ softmax_scale,
634
+ causal,
635
+ window_size_left,
636
+ window_size_right,
637
+ alibi_slopes,
638
+ deterministic,
639
+ attn_bias,
640
+ return_attn_probs,
641
+ return_ds,
642
+ )
flash_attention_v2_bias.py ADDED
@@ -0,0 +1,859 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 BAAI
2
+ # Copyright 2024 CATIE
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Modifications to the orignal file
17
+ # - Support for biases following https://github.com/FlagOpen/FlagAttention/pull/5
18
+ # - Support for shape (1,1,q,k) biases
19
+
20
+ import math
21
+ import torch
22
+ import triton
23
+ import triton.language as tl
24
+
25
+ # Wrapper for triton kernel for torch.compile - should be unecessary for PyTorch 2.3 ?
26
+ torch.library.define("flasht5::flash_attn_v2_fwd", "(Tensor q, Tensor k, Tensor v, Tensor bias, bool causal, float sm_scale, int BLOCK_M, int BLOCK_N, int num_warps, int num_stages) -> (Tensor, Tensor)")
27
+
28
+ @torch.library.impl("flasht5::flash_attn_v2_fwd", "default")
29
+ def flash_attn_v2_fwd(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
30
+
31
+ B, H, M, D = q.shape
32
+ N = k.shape[2]
33
+ P_SEQ = N - M
34
+ larger_m = M > N
35
+
36
+ # Trick to support shape such as (1, 1, seqlen_q, seqlen_k)
37
+ bias_batch_stride = bias.stride(0) if bias is not None else 0
38
+ bias_heads_stride = bias.stride(1) if bias is not None else 0
39
+ if bias is not None:
40
+ if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):
41
+ bias_batch_stride = 0
42
+ if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):
43
+ bias_heads_stride = 0
44
+
45
+ divisible_m = M % BLOCK_M == 0
46
+ divisible_n = N % BLOCK_N == 0
47
+ # consider using 3d grid to avoid div & rem
48
+ grid = (triton.cdiv(M, BLOCK_M), H, B)
49
+ o = torch.empty_like(q)
50
+ L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
51
+
52
+ _fwd_kernel[grid](
53
+ q, k, v, bias, sm_scale,
54
+ L, o,
55
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
56
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
57
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
58
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
59
+ bias_batch_stride, bias_heads_stride,
60
+ bias.stride(2) if bias is not None else 0,
61
+ bias.stride(3) if bias is not None else 0,
62
+ B, H, M, N, P_SEQ,
63
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
64
+ IS_CAUSAL=causal, LARGER_M=larger_m,
65
+ DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
66
+ HAS_BIAS=(bias is not None),
67
+ num_warps=num_warps, num_stages=num_stages,
68
+ )
69
+
70
+ return o, L
71
+
72
+
73
+ @torch.library.impl_abstract("flasht5::flash_attn_v2_fwd", flash_attn_v2_fwd)
74
+ def flash_attn_v2_fwd_abstract(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
75
+ B, H, M, D = q.shape
76
+ o = torch.empty_like(q)
77
+ L = torch.empty((B, H, M), dtype=torch.float32, device=q.device)
78
+
79
+ return o, L
80
+
81
+ torch.library.define("flasht5::flash_attn_v2_bwd", "(Tensor o, Tensor do, Tensor q, Tensor k, Tensor v, Tensor bias, Tensor L, bool causal, float sm_scale, int BLOCK_M, int BLOCK_N, int num_warps, int num_stages) -> (Tensor, Tensor, Tensor, Tensor)")
82
+
83
+ @torch.library.impl("flasht5::flash_attn_v2_bwd", "default")
84
+ def flash_attn_v2_bwd(o, do, q, k, v, bias, L, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
85
+
86
+ B, H, M, D = q.shape
87
+ N = k.shape[2]
88
+ P_SEQ = N - M
89
+ larger_m = M > N
90
+
91
+ divisible_m = M % BLOCK_M == 0
92
+ divisible_n = N % BLOCK_N == 0
93
+
94
+ # Trick to support shape such as (1, 1, seqlen_q, seqlen_k)
95
+ bias_batch_stride = bias.stride(0) if bias is not None else 0
96
+ bias_heads_stride = bias.stride(1) if bias is not None else 0
97
+ if bias is not None:
98
+ if (bias.shape[0] != q.shape[0]) and (bias.shape[0] == 1):
99
+ bias_batch_stride = 0
100
+ if (bias.shape[1] != q.shape[1]) and (bias.shape[1] == 1):
101
+ bias_heads_stride = 0
102
+
103
+ delta = torch.empty_like(L)
104
+ grid = (triton.cdiv(M, BLOCK_M), H, B)
105
+
106
+ _bwd_preprocess[grid](
107
+ o, do,
108
+ delta,
109
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
110
+ do.stride(0), do.stride(1), do.stride(2), do.stride(3),
111
+ delta.stride(0), delta.stride(1), delta.stride(2),
112
+ M,
113
+ BLOCK_M=BLOCK_M, D_HEAD=D,
114
+ DIVISIBLE_M=divisible_m,
115
+ )
116
+
117
+ dk = torch.empty_like(k)
118
+ dv = torch.empty_like(v)
119
+
120
+ HAS_BIAS = bias is not None
121
+ RETURN_DS = HAS_BIAS
122
+ USE_DS_ATOMIC_ADD = (bias_batch_stride == 0) or (bias_heads_stride == 0)
123
+ ds = None
124
+ if RETURN_DS:
125
+ ds = torch.empty_like(bias)
126
+ if USE_DS_ATOMIC_ADD:
127
+ ds = ds.zero_()
128
+
129
+ grid = (triton.cdiv(N, BLOCK_N), H, B)
130
+ _bwd_kv_kernel[grid](
131
+ q, k, v, bias, sm_scale, do,
132
+ dk, dv, ds,
133
+ L, delta,
134
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
135
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
136
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
137
+ bias_batch_stride, bias_heads_stride,
138
+ bias.stride(2) if HAS_BIAS else 0,
139
+ bias.stride(3) if HAS_BIAS else 0,
140
+ do.stride(0), do.stride(1), do.stride(2), do.stride(3),
141
+ dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
142
+ dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
143
+ B, H, M, N, P_SEQ,
144
+ BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
145
+ DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
146
+ HAS_BIAS=HAS_BIAS,
147
+ RETURN_DS=RETURN_DS, USE_DS_ATOMIC_ADD=USE_DS_ATOMIC_ADD,
148
+ num_stages=num_stages, num_warps=num_warps,
149
+ )
150
+
151
+ dq = torch.empty_like(q)
152
+ grid = (triton.cdiv(M, BLOCK_M), H, B)
153
+ _bwd_q_kernel[grid](
154
+ q, k, v, bias, sm_scale, do,
155
+ dq,
156
+ L, delta,
157
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
158
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
159
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
160
+ bias_batch_stride, bias_heads_stride,
161
+ bias.stride(2) if HAS_BIAS else 0,
162
+ bias.stride(3) if HAS_BIAS else 0,
163
+ do.stride(0), do.stride(1), do.stride(2), do.stride(3),
164
+ dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
165
+ B, H, M, N, P_SEQ,
166
+ BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
167
+ CAUSAL=causal, LARGER_M=larger_m,
168
+ DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
169
+ HAS_BIAS=HAS_BIAS,
170
+ num_stages=num_stages, num_warps = num_warps,
171
+ )
172
+
173
+ return dq, dk, dv, ds
174
+
175
+ @torch.library.impl_abstract("flasht5::flash_attn_v2_bwd", flash_attn_v2_bwd)
176
+ def cross_entropy_triton_bwd_abstract(o, do, q, k, v, bias, L, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages):
177
+ dq = torch.empty_like(q)
178
+ dk = torch.empty_like(k)
179
+ dv = torch.empty_like(v)
180
+ ds = torch.empty_like(bias) if bias is not None else None
181
+
182
+ return dq, dk, dv, ds
183
+
184
+ class FlashAttention(torch.autograd.Function):
185
+ @staticmethod
186
+ def forward(ctx, q, k, v, bias, causal, sm_scale):
187
+ Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
188
+
189
+ assert Dq == Dk == Dv
190
+ assert Dk in {16, 32, 64, 128}
191
+
192
+ B, H, M, D = q.shape
193
+ N = k.shape[2]
194
+
195
+ if sm_scale is None:
196
+ sm_scale = 1. / math.sqrt(D)
197
+
198
+ config = get_fwd_config(B, H, M, N, D, causal)
199
+ BLOCK_M, BLOCK_N, num_stages, num_warps = config
200
+
201
+ o, L = torch.ops.flasht5.flash_attn_v2_fwd(q, k, v, bias, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages)
202
+
203
+ # autograd context maintenance
204
+ ctx.save_for_backward(q, k, v, bias, o, L)
205
+ ctx.sm_scale = sm_scale
206
+ ctx.causal = causal
207
+
208
+ return o
209
+
210
+ @staticmethod
211
+ def backward(ctx, do, *ignored):
212
+ q, k, v, bias, o, L = ctx.saved_tensors
213
+ sm_scale = ctx.sm_scale
214
+ causal = ctx.causal
215
+
216
+ B, H, M, D = q.shape
217
+ N = k.shape[2]
218
+
219
+ if sm_scale is None:
220
+ sm_scale = 1. / math.sqrt(D)
221
+
222
+ config = get_bwd_config(B, H, M, N, D, causal)
223
+ BLOCK_M, BLOCK_N, num_stages, num_warps = config
224
+
225
+ dq, dk, dv, ds = torch.ops.flasht5.flash_attn_v2_bwd(o, do, q, k, v, bias, L, causal, sm_scale, BLOCK_M, BLOCK_N, num_warps, num_stages)
226
+
227
+ return dq, dk, dv, ds, None, None, None, None
228
+
229
+
230
+ def attention(q, k, v, bias, causal=False, sm_scale=None):
231
+ """
232
+ An implementation of FlashAttention v2(https://arxiv.org/abs/2307.08691).
233
+
234
+ Arguments:
235
+ q(torch.Tensor): The first queries. The shape is (batch_size, nheads, seqlen_q, headdim).
236
+ k(torch.Tensor): The first keys. The shape is (batch_size, nheads, seqlen_k, headdim).
237
+ v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim).
238
+ causal(bool): Whether causal masking is applied to attention scores before applying softmax.
239
+ sm_scale(float): The scaling of attention scores before applying softmax.
240
+
241
+ Returns:
242
+ out(torch.Tensor): The output. The shape is (batch_size, nheads, seqlen_q, headdim).
243
+ """
244
+ return FlashAttention.apply(q, k, v, bias, causal, sm_scale)
245
+
246
+
247
+ # --------------------------- Forward ---------------------------
248
+ # NOTE: this function can be overwritten at runtime to use your custom config
249
+ def get_fwd_config(B, H, M, N, D, causal):
250
+ if torch.cuda.get_device_capability() == (8, 0):
251
+ if not causal:
252
+ if D <= 64:
253
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
254
+ else:
255
+ if M <= 1024:
256
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
257
+ else:
258
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
259
+ else:
260
+ if D <= 64:
261
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 4
262
+ else:
263
+ if M <= 1024:
264
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
265
+ else:
266
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
267
+ elif torch.cuda.get_device_capability() == (8, 6):
268
+ if not causal:
269
+ if D <= 64:
270
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
271
+ else:
272
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
273
+ else: # causal
274
+ if D <= 64:
275
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4
276
+ else:
277
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4
278
+ else:
279
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
280
+ return (BLOCK_M, BLOCK_N, num_stages, num_warps)
281
+
282
+
283
+ @triton.jit
284
+ def _fwd_kernel(
285
+ Q, K, V, B, sm_scale,
286
+ L, O,
287
+ stride_qz, stride_qh, stride_qm, stride_qk,
288
+ stride_kz, stride_kh, stride_kn, stride_kk,
289
+ stride_vz, stride_vh, stride_vn, stride_vk,
290
+ stride_oz, stride_oh, stride_om, stride_ok,
291
+ stride_bz, stride_bh, stride_bm, stride_bn,
292
+ Z, H, M, N, P_SEQ,
293
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
294
+ IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
295
+ DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
296
+ HAS_BIAS: tl.constexpr,
297
+ ):
298
+ input_dtype = Q.dtype.element_ty
299
+ # -- grid id --
300
+ start_m = tl.program_id(0)
301
+ off_h = tl.program_id(1)
302
+ off_z = tl.program_id(2)
303
+
304
+ # scale sm_scale by log_2(e) and use
305
+ # 2^x instead of exp in the loop because CSE and LICM
306
+ # don't work as expected with `exp` in the loop
307
+ log2e: tl.constexpr = 1.4426950408889634
308
+
309
+ # offset pointers for (batch, head)
310
+ Q += off_z * stride_qz + off_h * stride_qh
311
+ K += off_z * stride_kz + off_h * stride_kh
312
+ V += off_z * stride_vz + off_h * stride_vh
313
+ O += off_z * stride_oz + off_h * stride_oh
314
+ if HAS_BIAS:
315
+ B += off_z * stride_bz + off_h * stride_bh
316
+ L += (off_z * H + off_h) * M # l's shape is (B, H, M)
317
+
318
+ offs_m_base = tl.arange(0, BLOCK_M)
319
+ offs_m = start_m * BLOCK_M + offs_m_base
320
+ offs_n_base = tl.arange(0, BLOCK_N)
321
+ offs_k = tl.arange(0, BLOCK_DMODEL)
322
+
323
+ # initialize pointers to value-like data
324
+ q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
325
+ o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL)
326
+ l_ptrs = L + offs_m
327
+
328
+ # initialize pointer to m and l, fp32 for accumulators
329
+ m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32)
330
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
331
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
332
+
333
+ # load q
334
+ mask_m = offs_m < M
335
+ if DIVISIBLE_M:
336
+ q = tl.load(q_ptrs, cache_modifier=".cg")
337
+ else:
338
+ q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg")
339
+
340
+ #Dot I trick: to place q in registers, it saves shared memory
341
+ if BLOCK_DMODEL < 128:
342
+ I = tl.where(offs_k[:, None] == offs_k,
343
+ tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
344
+ tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
345
+ q = tl.dot(q, I).to(input_dtype)
346
+ # else:
347
+ # I = tl.where(offs_m_base[:, None] == offs_m_base,
348
+ # tl.full((BLOCK_M, BLOCK_M), 1.0, dtype=input_dtype),
349
+ # tl.full((BLOCK_M, BLOCK_M), 0.0, dtype=input_dtype))
350
+ # q = tl.dot(I, q).to(input_dtype)
351
+
352
+ # NOTE: Loop-Bound-For-N
353
+ # The indices in m-dimension that this block may access is in `[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`.
354
+ # According to the rule of causal masking, then max index in n-dimension that this block may access
355
+ # is `P_SEQ + (start_m + 1) * BLOCK_M`.
356
+ # However, the upper bound of index in n-dimension should never exceed the sequence length of k/v(`P_SEQ + N_CTX`).
357
+ # `P_SEQ + (start_m + 1) * BLOCK_M` may be larger than `N`.
358
+ # At this case, there would be illegal memory access when loading k & v tiles
359
+ # if mask_n is not applied for loading(only when `DIVISIBLE_N`` is true).
360
+ # See also https://github.com/FlagOpen/FlagAttention/pull/8
361
+ if IS_CAUSAL:
362
+ hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
363
+ if LARGER_M:
364
+ hi = tl.maximum(0, hi)
365
+ else:
366
+ hi = N
367
+
368
+ # loop over k, v and update accumulators
369
+ offs_n_init = offs_n_base
370
+ k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) # (BLOCK_DMODEL, BLOCK_N)
371
+ v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
372
+ if HAS_BIAS:
373
+ bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn) # (BLOCK_M, BLOCK_N)
374
+
375
+ for start_n in range(0, hi, BLOCK_N):
376
+ start_n = tl.multiple_of(start_n, BLOCK_N)
377
+ offs_n = start_n + offs_n_base
378
+
379
+ # -- load k, v --
380
+ mask_n = offs_n < N
381
+ if DIVISIBLE_N:
382
+ k = tl.load(k_ptrs, cache_modifier=".cg")
383
+ v = tl.load(v_ptrs, cache_modifier=".cg")
384
+ else:
385
+ k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg")
386
+ v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg")
387
+
388
+ # -- load bias --
389
+ if HAS_BIAS:
390
+ if DIVISIBLE_M and DIVISIBLE_N:
391
+ b = tl.load(bias_ptrs)
392
+ else:
393
+ b = tl.load(bias_ptrs, mask_m[:, None] & mask_n[None, :])
394
+
395
+ # -- compute qk ---
396
+ s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
397
+ s += tl.dot(q, k) * sm_scale
398
+ if HAS_BIAS:
399
+ s += b
400
+
401
+ if not DIVISIBLE_N:
402
+ s = tl.where(mask_n[None, :], s, float("-inf"))
403
+ if IS_CAUSAL:
404
+ causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :]
405
+ s = tl.where(causal_mask, s, float("-inf"))
406
+
407
+ # -- compute scaling constant ---
408
+ m_i_new = tl.maximum(m_i, tl.max(s, 1))
409
+ alpha = tl.math.exp2((m_i - m_i_new)*log2e)
410
+ p = tl.math.exp2((s - m_i_new[:, None])*log2e)
411
+
412
+ # -- scale and update acc: acc *= alpha[:, None]--
413
+ acc *= alpha[:, None]
414
+ acc += tl.dot(p.to(input_dtype), v)
415
+
416
+ # -- update m_i and l_i --
417
+ l_i = l_i * alpha + tl.sum(p, 1)
418
+ m_i = m_i_new
419
+ # update pointers
420
+ k_ptrs += BLOCK_N * stride_kn
421
+ v_ptrs += BLOCK_N * stride_vn
422
+ if HAS_BIAS:
423
+ bias_ptrs += BLOCK_N * stride_bn
424
+
425
+ # write back l & o
426
+ if IS_CAUSAL and LARGER_M:
427
+ is_empty_line = (offs_m + P_SEQ) < 0
428
+ acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None]))
429
+ l = tl.where(is_empty_line, float("-inf"), m_i + tl.log(l_i))
430
+ else:
431
+ acc = acc * (1.0 / l_i[:, None])
432
+ l = m_i + tl.log(l_i) # log(normalizer)
433
+
434
+ if DIVISIBLE_M:
435
+ tl.store(l_ptrs, l, cache_modifier=".cg")
436
+ tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg")
437
+ else:
438
+ tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg")
439
+ tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg")
440
+
441
+
442
+ # --------------------------- Backward ---------------------------
443
+ # NOTE: this function can be overwritten at runtime to use your custom config
444
+ def get_bwd_config(B, H, M, N, D, causal):
445
+ if torch.cuda.get_device_capability() == (8, 0):
446
+ if not causal:
447
+ BLOCK_M = 128 if D <= 64 else 64
448
+ BLOCK_N = 64
449
+ num_stages = 2
450
+ num_warps = 4
451
+ else:
452
+ BLOCK_M = 64
453
+ BLOCK_N = 64
454
+ num_stages = 3 if D <= 64 else 2
455
+ num_warps = 4
456
+ elif torch.cuda.get_device_capability() == (8, 6): # tune for RTX-3090, device_capability(8, 6)
457
+ if not causal:
458
+ if D <= 64:
459
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
460
+ else:
461
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 8
462
+ else:
463
+ if D <= 64:
464
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4
465
+ else:
466
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 2, 4
467
+ else:
468
+ BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
469
+ return (BLOCK_M, BLOCK_N, num_stages, num_warps)
470
+
471
+
472
+ @triton.jit
473
+ def _bwd_preprocess(
474
+ Out, DO,
475
+ Delta,
476
+ stride_oz, stride_oh, stride_om, stride_ok,
477
+ stride_doz, stride_doh, stride_dom, stride_dok,
478
+ stride_dz, stride_dh, stride_dm,
479
+ M,
480
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
481
+ DIVISIBLE_M: tl.constexpr,
482
+ ):
483
+ off_h = tl.program_id(1)
484
+ off_z = tl.program_id(2)
485
+ Out += off_z * stride_oz + off_h * stride_oh
486
+ DO += off_z * stride_doz + off_h * stride_doh
487
+ Delta += off_z * stride_dz + off_h * stride_dh
488
+
489
+ # compute (Out * Dout).sum() for vector interpretation
490
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
491
+ off_n = tl.arange(0, D_HEAD)
492
+
493
+ # load
494
+ o_ptrs = Out + off_m[:, None] * stride_om + off_n[None, :] * stride_ok
495
+ do_ptrs = DO + off_m[:, None] * stride_dom + off_n[None, :] * stride_dok
496
+
497
+ if DIVISIBLE_M:
498
+ o = tl.load(o_ptrs).to(tl.float32)
499
+ do = tl.load(do_ptrs).to(tl.float32)
500
+ else:
501
+ mask_m = off_m < M
502
+ o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32)
503
+ do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32)
504
+
505
+ # compute
506
+ delta = tl.sum(o * do, axis=1)
507
+ # write-back
508
+ d_ptrs = Delta + off_m * stride_dm
509
+ if DIVISIBLE_M:
510
+ tl.store(d_ptrs, delta)
511
+ else:
512
+ tl.store(d_ptrs, delta, mask=mask_m)
513
+
514
+
515
+ @triton.jit
516
+ def _bwd_kv_kernel(
517
+ Q, K, V, B, sm_scale, DO,
518
+ DK, DV, DS,
519
+ L,
520
+ D,
521
+ stride_qz, stride_qh, stride_qm, stride_qk,
522
+ stride_kz, stride_kh, stride_kn, stride_kk,
523
+ stride_vz, stride_vh, stride_vn, stride_vk,
524
+ stride_bz, stride_bh, stride_bm, stride_bn,
525
+ stride_doz, stride_doh, stride_dom, stride_dok,
526
+ stride_dkz, stride_dkh, stride_dkn, stride_dkk,
527
+ stride_dvz, stride_dvh, stride_dvn, stride_dvk,
528
+ Z, H, M, N, P_SEQ,
529
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
530
+ CAUSAL: tl.constexpr,
531
+ DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
532
+ HAS_BIAS: tl.constexpr,
533
+ RETURN_DS: tl.constexpr, USE_DS_ATOMIC_ADD: tl.constexpr,
534
+ ):
535
+ input_dtype = Q.dtype.element_ty
536
+ # -- grid id --
537
+ start_n = tl.program_id(0)
538
+ off_h = tl.program_id(1)
539
+ off_z = tl.program_id(2)
540
+ log2e: tl.constexpr = 1.4426950408889634
541
+ qk_scale = sm_scale * log2e
542
+
543
+ # offset pointers for (batch, head)
544
+ Q += off_z * stride_qz + off_h * stride_qh
545
+ K += off_z * stride_kz + off_h * stride_kh
546
+ V += off_z * stride_vz + off_h * stride_vh
547
+ if HAS_BIAS:
548
+ B += off_z * stride_bz + off_h * stride_bh
549
+ DO += off_z * stride_doz + off_h * stride_doh
550
+
551
+ # offset pointers for batch/head
552
+ DK += off_z * stride_dkz + off_h * stride_dkh
553
+ DV += off_z * stride_dvz + off_h * stride_dvh
554
+ if RETURN_DS:
555
+ DS += off_z * stride_bz + off_h * stride_bh
556
+
557
+ # offset pointers for batch/head
558
+ D += (off_z * H + off_h) * M
559
+ L += (off_z * H + off_h) * M
560
+
561
+ if CAUSAL:
562
+ lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0)
563
+ lo = (lo // BLOCK_M) * BLOCK_M
564
+ else:
565
+ lo = 0
566
+
567
+ offs_m_init = lo + tl.arange(0, BLOCK_M)
568
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
569
+ offs_m_base = tl.arange(0, BLOCK_M)
570
+ offs_k = tl.arange(0, BLOCK_DMODEL)
571
+
572
+ # initialize pointers to value-like data
573
+ q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
574
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
575
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
576
+ do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
577
+
578
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) # (BLOCK_N, BLOCK_DMODEL)
579
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) # (BLOCK_N, BLOCK_DMODEL)
580
+
581
+ if HAS_BIAS:
582
+ bias_ptrs = B + (offs_m_init[:, None] * stride_bm + offs_n[None, :] * stride_bn)
583
+
584
+ if RETURN_DS:
585
+ ds_ptrs = DS + (offs_m_init[:, None] * stride_bm + offs_n[None, :] * stride_bn)
586
+
587
+ # k and v stay in SRAM throughout
588
+ mask_n = offs_n < N
589
+ if DIVISIBLE_N:
590
+ v = tl.load(v_ptrs)
591
+ k = tl.load(k_ptrs)
592
+ else:
593
+ v = tl.load(v_ptrs, mask=mask_n[:, None])
594
+ k = tl.load(k_ptrs, mask=mask_n[:, None])
595
+
596
+ # initialize dk amd dv
597
+ dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
598
+ dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
599
+
600
+ # loop over a col
601
+ for start_m in range(lo, M, BLOCK_M):
602
+ start_m = tl.multiple_of(start_m, BLOCK_M)
603
+ offs_m = start_m + offs_m_base
604
+ causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
605
+
606
+ # load q1, k1, q2, k2, v, do on-chip
607
+ mask_m = offs_m < M
608
+ if DIVISIBLE_M:
609
+ q = tl.load(q_ptrs)
610
+ else:
611
+ valid_mask = mask_m[:, None] # & mask_n
612
+ q = tl.load(q_ptrs, mask=mask_m[:, None])
613
+
614
+ # load bias
615
+ if HAS_BIAS:
616
+ if DIVISIBLE_M and DIVISIBLE_N:
617
+ b = tl.load(bias_ptrs)
618
+ else:
619
+ b = tl.load(bias_ptrs, mask=mask_m[:, None] & mask_n[None, :])
620
+
621
+ # recompute p = softmax(qk * sm_scale, dim=-1)
622
+ s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
623
+ s += tl.dot(q, tl.trans(k)) * sm_scale
624
+
625
+ if HAS_BIAS:
626
+ s += b
627
+
628
+ # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd)
629
+ # So masking on s is not needed.
630
+ # s = tl.where(valid_mask, s , float("-inf"))
631
+ # if CAUSAL:
632
+ # s = tl.where(causal_mask, s, float("-inf"))
633
+
634
+ # -- recompute p ---
635
+ if DIVISIBLE_M:
636
+ l = tl.load(L + offs_m)
637
+ else:
638
+ l = tl.load(L + offs_m, mask=mask_m)
639
+ p = tl.math.exp2((s - l[:, None])*log2e) # (BLOCK_M, BLOCK_N)
640
+
641
+ if not DIVISIBLE_M:
642
+ p = tl.where(valid_mask, p, 0.0)
643
+ if CAUSAL:
644
+ p = tl.where(causal_mask, p, 0.0)
645
+
646
+ # compute dv = dot(p, do)
647
+ if DIVISIBLE_M:
648
+ do = tl.load(do_ptrs)
649
+ else:
650
+ do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL)
651
+ dv += tl.dot(tl.trans(p.to(do.dtype)), do) # (BLOCK_N, BLOCK_DMODEL) # still correct
652
+
653
+ # compute dp = dot(v, do)
654
+ if DIVISIBLE_M:
655
+ delta = tl.load(D + offs_m)
656
+ else:
657
+ delta = tl.load(D + offs_m, mask=mask_m)
658
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
659
+ dp += tl.dot(do, tl.trans(v))
660
+
661
+ # compute ds = p * (dp - delta[:, None])
662
+ ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
663
+
664
+ if not DIVISIBLE_M:
665
+ ds = tl.where(valid_mask, ds, 0.0)
666
+ if CAUSAL:
667
+ ds = tl.where(causal_mask, ds, 0.0)
668
+ ds = ds.to(input_dtype)
669
+
670
+ if RETURN_DS:
671
+ if DIVISIBLE_M and DIVISIBLE_N:
672
+ if USE_DS_ATOMIC_ADD:
673
+ tl.atomic_add(ds_ptrs, ds)
674
+ else:
675
+ tl.store(ds_ptrs, ds)
676
+ else:
677
+ if USE_DS_ATOMIC_ADD:
678
+ tl.atomic_add(ds_ptrs, ds, mask=mask_m[:, None] & mask_n[None, :])
679
+ else:
680
+ tl.store(ds_ptrs, ds, mask=mask_m[:, None] & mask_n[None, :])
681
+
682
+ # compute dk = dot(ds.T, q) masking
683
+ dk += tl.dot(tl.trans(ds), q)
684
+
685
+ # increment pointers
686
+ q_ptrs += BLOCK_M * stride_qm
687
+ do_ptrs += BLOCK_M * stride_dom
688
+ if HAS_BIAS:
689
+ bias_ptrs += BLOCK_M * stride_bm
690
+ if RETURN_DS:
691
+ ds_ptrs += BLOCK_M * stride_bm
692
+
693
+ dk *= sm_scale
694
+ if DIVISIBLE_N:
695
+ tl.store(dk_ptrs, dk.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL)
696
+ tl.store(dv_ptrs, dv.to(input_dtype)) # (BLOCK_N, BLOCK_DMODEL,)
697
+ else:
698
+ tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL)
699
+ tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) # (BLOCK_N, BLOCK_DMODEL,)
700
+
701
+
702
+ @triton.jit
703
+ def _bwd_q_kernel(
704
+ Q, K, V, B, sm_scale, DO,
705
+ DQ,
706
+ L,
707
+ D,
708
+ stride_qz, stride_qh, stride_qm, stride_qk,
709
+ stride_kz, stride_kh, stride_kn, stride_kk,
710
+ stride_vz, stride_vh, stride_vn, stride_vk,
711
+ stride_bz, stride_bh, stride_bm, stride_bn,
712
+ stride_doz, stride_doh, stride_dom, stride_dok,
713
+ stride_dqz, stride_dqh, stride_dqm, stride_dqk,
714
+ Z, H, M, N, P_SEQ,
715
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
716
+ CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
717
+ DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
718
+ HAS_BIAS: tl.constexpr
719
+ ):
720
+ input_dtype = Q.dtype.element_ty
721
+ # -- grid id --
722
+ start_m = tl.program_id(0)
723
+ off_h = tl.program_id(1)
724
+ off_z = tl.program_id(2)
725
+
726
+ # scale sm_scale by log_2(e) and use
727
+ # 2^x instead of exp in the loop because CSE and LICM
728
+ # don't work as expected with `exp` in the loop
729
+ log2e: tl.constexpr = 1.4426950408889634
730
+
731
+ # offset pointers for (batch, head)
732
+ Q += off_z * stride_qz + off_h * stride_qh
733
+ K += off_z * stride_kz + off_h * stride_kh
734
+ V += off_z * stride_vz + off_h * stride_vh
735
+ if HAS_BIAS:
736
+ B += off_z * stride_bz + off_h * stride_bh
737
+ DO += off_z * stride_doz + off_h * stride_doh
738
+ D += (off_z * H + off_h) * M
739
+ L += (off_z * H + off_h) * M
740
+
741
+ # offset pointers for batch/head
742
+ DQ += off_z * stride_dqz + off_h * stride_dqh
743
+
744
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
745
+ offs_n_base = tl.arange(0, BLOCK_N)
746
+ offs_n_init = offs_n_base
747
+ offs_k = tl.arange(0, BLOCK_DMODEL)
748
+
749
+ # initialize pointers to value-like data
750
+ q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL)
751
+ k_ptrs = K + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL)
752
+ v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL)
753
+
754
+ if HAS_BIAS:
755
+ bias_ptrs = B + (offs_m[:, None] * stride_bm + offs_n_init[None, :] * stride_bn)
756
+
757
+ dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) # (BLOCK_M, BLOCK_DMODEL)
758
+ do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (BLOCK_M, BLOCK_DMODEL)
759
+
760
+ # pointer to row-wise quantities in value-like data
761
+ d_ptrs = D + offs_m
762
+ l_ptrs = L + offs_m
763
+
764
+ # load q: it will stay in SRAM throughout
765
+ mask_m = offs_m < M
766
+ if DIVISIBLE_M:
767
+ q = tl.load(q_ptrs)
768
+ do = tl.load(do_ptrs)
769
+ delta = tl.load(d_ptrs)
770
+ l = tl.load(l_ptrs)
771
+ else:
772
+ q = tl.load(q_ptrs, mask=mask_m[:, None])
773
+ do = tl.load(do_ptrs, mask=mask_m[:, None])
774
+ delta = tl.load(d_ptrs, mask=mask_m)
775
+ l = tl.load(l_ptrs, mask=mask_m)
776
+
777
+ # initialize dq
778
+ dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
779
+
780
+ # loop over k, v and update accumulator
781
+ # see note "Loop-Bound-For-N"
782
+ if CAUSAL:
783
+ hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M)
784
+ if LARGER_M:
785
+ hi = tl.maximum(0, hi)
786
+ else:
787
+ hi = N
788
+
789
+ # loop over a row
790
+ for start_n in range(0, hi, BLOCK_N):
791
+ offs_n = start_n + offs_n_base
792
+
793
+ # load k1, k2, v on chip
794
+ mask_n = offs_n < N
795
+ if DIVISIBLE_N:
796
+ v = tl.load(v_ptrs)
797
+ k = tl.load(k_ptrs)
798
+ else:
799
+ v = tl.load(v_ptrs, mask=mask_n[:, None])
800
+ k = tl.load(k_ptrs, mask=mask_n[:, None])
801
+
802
+ # load bias
803
+ if HAS_BIAS:
804
+ if DIVISIBLE_M and DIVISIBLE_N:
805
+ b = tl.load(bias_ptrs)
806
+ else:
807
+ b = tl.load(bias_ptrs, mask=mask_m[:, None] & mask_n[None, :])
808
+
809
+ # recompute p = softmax(qk * sm_scale, dim=-1)
810
+ if not DIVISIBLE_N:
811
+ valid_mask = mask_n # & mask_m[:, None]
812
+ if CAUSAL:
813
+ causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) # (BLOCK_M, BLOCK_N)
814
+
815
+ s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
816
+ s += tl.dot(q, tl.trans(k)) * sm_scale
817
+ if HAS_BIAS:
818
+ s += b
819
+
820
+ # NOTE: since softmax in backward is pointwise, the normalizer has been saved in fwd)
821
+ # So masking on s is not needed.
822
+ # if CAUSAL:
823
+ # s = tl.where(causal_mask & valid_mask, s, float("-inf"))
824
+ # else:
825
+ # s = tl.where(valid_mask, s, float("-inf"))
826
+ p = tl.math.exp2((s - l[:, None])*log2e) # (BLOCK_M, BLOCK_N)
827
+
828
+ # compute dp = dot(v, do)
829
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
830
+ dp += tl.dot(do.to(input_dtype), tl.trans(v))
831
+ # no need to mask dp
832
+ # if CAUSAL:
833
+ # dp = tl.where(causal_mask & valid_mask, dp, 0.0)
834
+ # else:
835
+ # dp = tl.where(valid_mask, dp, 0.0)
836
+
837
+ # compute ds = p * (dp - delta[:, None])
838
+ # move scale out to dq at last
839
+ ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N)
840
+
841
+ # mask ds to ensure no small values
842
+ if not DIVISIBLE_N:
843
+ ds = tl.where(valid_mask, ds, 0.0)
844
+ if CAUSAL:
845
+ ds = tl.where(causal_mask, ds, 0.0)
846
+
847
+ dq += tl.dot(ds.to(input_dtype), k)
848
+
849
+ # increment pointers
850
+ k_ptrs += BLOCK_N * stride_kn
851
+ v_ptrs += BLOCK_N * stride_vn
852
+ if HAS_BIAS:
853
+ bias_ptrs += BLOCK_N * stride_bn
854
+
855
+ dq *= sm_scale
856
+ if DIVISIBLE_M:
857
+ tl.store(dq_ptrs, dq.to(input_dtype))
858
+ else:
859
+ tl.store(dq_ptrs, dq.to(input_dtype), mask=mask_m[:, None])
gated_mlp.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ from torch.cuda.amp import custom_bwd, custom_fwd
7
+
8
+ def to_tl_dtype(input):
9
+ if input == torch.float32:
10
+ return tl.float32
11
+ elif input == torch.float16:
12
+ return tl.float16
13
+ elif input == torch.bfloat16:
14
+ return tl.bfloat16
15
+ elif input == torch.int64:
16
+ return tl.int64
17
+ else:
18
+ raise ValueError(f"Unable to convert the given input: '{input}'.")
19
+
20
+ ## Activation function from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
21
+
22
+ _kAlpha = math.sqrt(2.0 / math.pi)
23
+
24
+ def gelu_torch(x):
25
+ """
26
+ GeLU_ activation - Gaussian error linear unit
27
+
28
+ .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
29
+ """
30
+ return 0.5 * x * (1 + torch.tanh(_kAlpha * (x + 0.044715 * x * x * x)))
31
+
32
+ def gelu_grad_torch(x):
33
+ # CREDITS: Fast implementation proposed in
34
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
35
+ tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
36
+ return 0.5 * x * (
37
+ (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
38
+ ) + 0.5 * (1 + tanh_out)
39
+
40
+ # ReLU
41
+ @triton.jit
42
+ def tanh(x):
43
+ # Tanh is just a scaled sigmoid
44
+ return 2 * tl.sigmoid(2 * x) - 1
45
+
46
+ @triton.jit
47
+ def relu(x):
48
+ """
49
+ ReLU_ activation function
50
+
51
+ .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
52
+ """
53
+ return tl.where(x >= 0, x, 0.0)
54
+
55
+
56
+ @triton.jit
57
+ def relu_grad(x):
58
+ # ReLU is different from other activations
59
+ # in that it does not require the input to retrospectively compute its gradient
60
+ # here the input is the downstream gradient, and we return the upstream gradient directly
61
+ return tl.where(x >= 0, 1.0, 0.0)
62
+
63
+ @triton.jit
64
+ def gelu(x):
65
+ """
66
+ GeLU_ activation - Gaussian error linear unit
67
+
68
+ .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
69
+ """
70
+ return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))
71
+
72
+
73
+ @triton.jit
74
+ def gelu_grad(x):
75
+ # CREDITS: Fast implementation proposed in
76
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
77
+ tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
78
+ return 0.5 * x * (
79
+ (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
80
+ ) + 0.5 * (1 + tanh_out)
81
+
82
+
83
+ @triton.jit
84
+ def gated_matmul_fwd(
85
+ # Pointers to matrices
86
+ out, input, w1, w2,
87
+ act_input_1, act_input_2,
88
+ # Matrix dimensions
89
+ M, N, K,
90
+ stride_om,
91
+ stride_im,
92
+ stride_wn,
93
+ # Meta-parameters
94
+ dtype: tl.constexpr,
95
+ BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
96
+ BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
97
+ USE_GELU: tl.constexpr,
98
+ SAVE_ACTIVATION_INPUTS: tl.constexpr,
99
+ IS_EVEN_MNK: tl.constexpr
100
+ ):
101
+
102
+ """
103
+ Kernel for computing Out = activation(A x W + C)
104
+
105
+ - Input has shape (M, K)
106
+ - Weight 1 has shape (K, N)
107
+ - Weight 2 has shape (K, N)
108
+ - Output has shape (M, N)
109
+
110
+ """
111
+
112
+ pid = tl.program_id(0)
113
+
114
+ num_pid_m = tl.cdiv(M, BLOCK_M) # number of program ids along the M axis
115
+ num_pid_n = tl.cdiv(N, BLOCK_N) # number of programs ids along the N axis
116
+
117
+ num_pid_in_group = GROUP_M * num_pid_n # number of programs in group
118
+ group_id = pid // num_pid_in_group # id of the group this program is in
119
+ first_pid_m = group_id * GROUP_M # row-id of the first program in the group
120
+ GROUP_M = min(
121
+ num_pid_m - first_pid_m, GROUP_M
122
+ ) # if `num_pid_m` isn't divisible by `GROUP_M`, the last group is smaller
123
+
124
+ # *within groups*, programs are ordered in a column-major order
125
+ # row-id /col-id of the program in the *launch grid*
126
+ pid_m = first_pid_m + (pid % GROUP_M)
127
+ pid_n = (pid % num_pid_in_group) // GROUP_M
128
+
129
+ input_block_ptr = tl.make_block_ptr(
130
+ base=input,
131
+ shape=(M, K),
132
+ strides=(stride_im, 1),
133
+ offsets=(pid_m * BLOCK_M, 0),
134
+ block_shape=(BLOCK_M, BLOCK_K),
135
+ order=(1, 0),
136
+ )
137
+
138
+ w1_block_ptr = tl.make_block_ptr(
139
+ base=w1,
140
+ shape=(K, N),
141
+ strides=(1, stride_wn),
142
+ offsets=(0, pid_n * BLOCK_N),
143
+ block_shape=(BLOCK_K, BLOCK_N),
144
+ order=(0, 1),
145
+ )
146
+
147
+ w2_block_ptr = tl.make_block_ptr(
148
+ base=w2,
149
+ shape=(K, N),
150
+ strides=(1, stride_wn),
151
+ offsets=(0, pid_n * BLOCK_N),
152
+ block_shape=(BLOCK_K, BLOCK_N),
153
+ order=(0, 1),
154
+ )
155
+
156
+ # initialize and iteratively update accumulator
157
+ acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
158
+ acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
159
+
160
+ for i in range(0, K, BLOCK_K):
161
+
162
+ if IS_EVEN_MNK:
163
+ x = tl.load(input_block_ptr)
164
+ w1_blk = tl.load(w1_block_ptr)
165
+ w2_blk = tl.load(w2_block_ptr)
166
+ else:
167
+ x = tl.load(input_block_ptr, boundary_check=(0, 1))
168
+ w1_blk = tl.load(w1_block_ptr, boundary_check=(0, 1))
169
+ w2_blk = tl.load(w2_block_ptr, boundary_check=(0, 1))
170
+
171
+ acc1 += tl.dot(x, w1_blk)
172
+ acc2 += tl.dot(x, w2_blk)
173
+
174
+ input_block_ptr = tl.advance(input_block_ptr, (0, BLOCK_K))
175
+ w1_block_ptr = tl.advance(w1_block_ptr, (BLOCK_K, 0))
176
+ w2_block_ptr = tl.advance(w2_block_ptr, (BLOCK_K, 0))
177
+
178
+ if SAVE_ACTIVATION_INPUTS:
179
+ act_in_1_ptrs = tl.make_block_ptr(
180
+ base=act_input_1,
181
+ shape=(M, N),
182
+ strides=(stride_om, 1),
183
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
184
+ block_shape=(BLOCK_M, BLOCK_N),
185
+ order=(1, 0),
186
+ )
187
+
188
+ act_in_2_ptrs = tl.make_block_ptr(
189
+ base=act_input_2,
190
+ shape=(M, N),
191
+ strides=(stride_om, 1),
192
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
193
+ block_shape=(BLOCK_M, BLOCK_N),
194
+ order=(1, 0),
195
+ )
196
+
197
+ if IS_EVEN_MNK:
198
+ tl.store(act_in_1_ptrs, acc1.to(dtype))
199
+ tl.store(act_in_2_ptrs, acc2.to(dtype))
200
+ else:
201
+ tl.store(act_in_1_ptrs, acc1.to(dtype), boundary_check=(0, 1))
202
+ tl.store(act_in_2_ptrs, acc2.to(dtype), boundary_check=(0, 1))
203
+
204
+ if USE_GELU:
205
+ acc1 = gelu(acc1)
206
+ else:
207
+ acc1 = relu(acc1)
208
+
209
+ # gating
210
+ acc = acc1 * acc2
211
+
212
+ # write back result
213
+ out_ptrs = tl.make_block_ptr(
214
+ base=out,
215
+ shape=(M, N),
216
+ strides=(stride_om, 1),
217
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
218
+ block_shape=(BLOCK_M, BLOCK_N),
219
+ order=(1, 0),
220
+ )
221
+
222
+ if IS_EVEN_MNK:
223
+ tl.store(out_ptrs, acc.to(dtype))
224
+ else:
225
+ tl.store(out_ptrs, acc.to(dtype), boundary_check=(0, 1))
226
+
227
+ @triton.jit
228
+ def gated_matmul_bwd_ygrad(
229
+ dout,
230
+ y1_grad, y2_grad,
231
+ act_input_1, act_input_2,
232
+ M, N,
233
+ stride_dom,
234
+ # Meta-parameters
235
+ dtype: tl.constexpr,
236
+ BLOCK_M: tl.constexpr,
237
+ BLOCK_N: tl.constexpr,
238
+ USE_GELU: tl.constexpr,
239
+ IS_EVEN_MNK: tl.constexpr):
240
+
241
+ """
242
+ Kernel for backward gated MLP
243
+
244
+ Ref :
245
+ y2_grad = torch.mul(gelu(x @ w1), dout)
246
+ y1_grad = torch.mul(gelu_grad(x @ w1) * (x @ w2), dout)
247
+ """
248
+
249
+ pid_m = tl.program_id(0)
250
+ pid_n = tl.program_id(1)
251
+
252
+ # block pointers
253
+ actin_1_block_ptr = tl.make_block_ptr(
254
+ base=act_input_1,
255
+ shape=(M, N),
256
+ strides=(stride_dom, 1),
257
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
258
+ block_shape=(BLOCK_M, BLOCK_N),
259
+ order=(1, 0),
260
+ )
261
+
262
+ actin_2_block_ptr = tl.make_block_ptr(
263
+ base=act_input_2,
264
+ shape=(M, N),
265
+ strides=(stride_dom, 1),
266
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
267
+ block_shape=(BLOCK_M, BLOCK_N),
268
+ order=(1, 0),
269
+ )
270
+
271
+ dout_block_ptr = tl.make_block_ptr(
272
+ base=dout,
273
+ shape=(M, N),
274
+ strides=(stride_dom, 1),
275
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
276
+ block_shape=(BLOCK_M, BLOCK_N),
277
+ order=(1, 0),
278
+ )
279
+
280
+ if IS_EVEN_MNK:
281
+ dout_blk = tl.load(dout_block_ptr)
282
+ actin_1_blk = tl.load(actin_1_block_ptr)
283
+ actin_2_blk = tl.load(actin_2_block_ptr)
284
+ else:
285
+ dout_blk = tl.load(dout_block_ptr, boundary_check=(0, 1))
286
+ actin_1_blk = tl.load(actin_1_block_ptr, boundary_check=(0, 1))
287
+ actin_2_blk = tl.load(actin_2_block_ptr, boundary_check=(0, 1))
288
+
289
+ if USE_GELU:
290
+ actin_act = gelu(actin_1_blk)
291
+ actin_act_grad = gelu_grad(actin_1_blk)
292
+ else:
293
+ actin_act = relu(actin_1_blk)
294
+ actin_act_grad = relu_grad(actin_1_blk)
295
+
296
+ actin_act *= dout_blk # y2_grad
297
+ actin_act_grad *= actin_2_blk
298
+ actin_act_grad *= dout_blk # y1_grad
299
+
300
+ y1_grad_ptrs = tl.make_block_ptr(
301
+ base=y1_grad,
302
+ shape=(M, N),
303
+ strides=(stride_dom, 1),
304
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
305
+ block_shape=(BLOCK_M, BLOCK_N),
306
+ order=(1, 0),
307
+ )
308
+
309
+ y2_grad_ptrs = tl.make_block_ptr(
310
+ base=y2_grad,
311
+ shape=(M, N),
312
+ strides=(stride_dom, 1),
313
+ offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
314
+ block_shape=(BLOCK_M, BLOCK_N),
315
+ order=(1, 0),
316
+ )
317
+
318
+ if IS_EVEN_MNK:
319
+ tl.store(y1_grad_ptrs, actin_act_grad.to(dtype))
320
+ tl.store(y2_grad_ptrs, actin_act.to(dtype))
321
+ else:
322
+ tl.store(y1_grad_ptrs, actin_act_grad.to(dtype), boundary_check=(0, 1))
323
+ tl.store(y2_grad_ptrs, actin_act.to(dtype), boundary_check=(0, 1))
324
+
325
+
326
+ @triton.jit
327
+ def gated_matmul_bwd_input(
328
+ # Pointers to matrices
329
+ w1, w2, # weights inputs
330
+ y1_grad, y2_grad, # partial computation
331
+ din, # outputs
332
+ # Matrix dimensions
333
+ M, N, K,
334
+ stride_dom, stride_im,
335
+ stride_wn,
336
+ # Meta-parameters
337
+ dtype: tl.constexpr,
338
+ BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
339
+ BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
340
+ IS_EVEN_MNK: tl.constexpr
341
+ ):
342
+
343
+ """
344
+ Kernel for backward gated MLP
345
+ We group along the N axis
346
+
347
+ Ref :
348
+ x_grad = torch.matmul(y2_grad, w2.t()) + torch.matmul(y1_grad, w1.t())
349
+ """
350
+
351
+ pid = tl.program_id(0)
352
+
353
+ num_pid_m = tl.cdiv(M, BLOCK_M) # number of program ids along the M axis
354
+ num_pid_k = tl.cdiv(K, BLOCK_K) # number of programs ids along the K axis
355
+
356
+ num_pid_in_group = GROUP_M * num_pid_k # number of programs in group
357
+ group_id = pid // num_pid_in_group # id of the group this program is in
358
+ first_pid_m = group_id * GROUP_M # row-id of the first program in the group
359
+ GROUP_M = min(
360
+ num_pid_m - first_pid_m, GROUP_M
361
+ ) # if `num_pid_m` isn't divisible by `GROUP_M`, the last group is smaller
362
+
363
+ # *within groups*, programs are ordered in a column-major order
364
+ # row-id /col-id of the program in the *launch grid*
365
+ pid_m = first_pid_m + (pid % GROUP_M)
366
+ pid_k = (pid % num_pid_in_group) // GROUP_M
367
+
368
+ y1_grad_block_ptr = tl.make_block_ptr(
369
+ base=y1_grad,
370
+ shape=(M, N),
371
+ strides=(stride_dom, 1),
372
+ offsets=(pid_m * BLOCK_M, 0),
373
+ block_shape=(BLOCK_M, BLOCK_N),
374
+ order=(1, 0),
375
+ )
376
+
377
+ y2_grad_block_ptr = tl.make_block_ptr(
378
+ base=y2_grad,
379
+ shape=(M, N),
380
+ strides=(stride_dom, 1),
381
+ offsets=(pid_m * BLOCK_M, 0),
382
+ block_shape=(BLOCK_M, BLOCK_N),
383
+ order=(1, 0),
384
+ )
385
+
386
+ w1_block_ptr = tl.make_block_ptr(
387
+ base=w1,
388
+ shape=(N, K),
389
+ strides=(stride_wn, 1),
390
+ offsets=(0, pid_k * BLOCK_K),
391
+ block_shape=(BLOCK_N, BLOCK_K),
392
+ order=(1, 0),
393
+ )
394
+
395
+ w2_block_ptr = tl.make_block_ptr(
396
+ base=w2,
397
+ shape=(N, K),
398
+ strides=(stride_wn, 1),
399
+ offsets=(0, pid_k * BLOCK_K),
400
+ block_shape=(BLOCK_N, BLOCK_K),
401
+ order=(1, 0),
402
+ )
403
+
404
+ # initialize and iteratively update accumulator
405
+ acc_dx = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
406
+
407
+ for i in range(0, N, BLOCK_N):
408
+
409
+ if IS_EVEN_MNK:
410
+ w1_blk = tl.load(w1_block_ptr)
411
+ w2_blk = tl.load(w2_block_ptr)
412
+ y1_grad_blk = tl.load(y1_grad_block_ptr)
413
+ y2_grad_blk = tl.load(y2_grad_block_ptr)
414
+ else:
415
+ w1_blk = tl.load(w1_block_ptr, boundary_check=(0, 1))
416
+ w2_blk = tl.load(w2_block_ptr, boundary_check=(0, 1))
417
+ y1_grad_blk = tl.load(y1_grad_block_ptr, boundary_check=(0, 1))
418
+ y2_grad_blk = tl.load(y2_grad_block_ptr, boundary_check=(0, 1))
419
+
420
+ acc_dx += tl.dot(y2_grad_blk, w2_blk)
421
+ acc_dx += tl.dot(y1_grad_blk, w1_blk)
422
+
423
+ w1_block_ptr = tl.advance(w1_block_ptr, (BLOCK_N, 0))
424
+ w2_block_ptr = tl.advance(w2_block_ptr, (BLOCK_N, 0))
425
+ y1_grad_block_ptr = tl.advance(y1_grad_block_ptr, (0, BLOCK_N))
426
+ y2_grad_block_ptr = tl.advance(y2_grad_block_ptr, (0, BLOCK_N))
427
+
428
+ # write back result
429
+ dx_ptrs = tl.make_block_ptr(
430
+ base=din,
431
+ shape=(M, K),
432
+ strides=(stride_im, 1),
433
+ offsets=(pid_m * BLOCK_M, pid_k * BLOCK_K),
434
+ block_shape=(BLOCK_M, BLOCK_K),
435
+ order=(1, 0),
436
+ )
437
+
438
+ if IS_EVEN_MNK:
439
+ tl.store(dx_ptrs, acc_dx.to(dtype))
440
+ else:
441
+ tl.store(dx_ptrs, acc_dx.to(dtype), boundary_check=(0, 1))
442
+
443
+
444
+ @triton.jit
445
+ def gated_matmul_bwd_weights(
446
+ # Pointers to matrices
447
+ input,
448
+ y1_grad, y2_grad, # precomputations
449
+ dw1, dw2, # outputs
450
+ # Matrix dimensions
451
+ M, N, K,
452
+ stride_dom, stride_im,
453
+ stride_wn,
454
+ # Meta-parameters
455
+ dtype: tl.constexpr,
456
+ BLOCK_M: tl.constexpr, GROUP_N: tl.constexpr,
457
+ BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
458
+ IS_EVEN_MNK: tl.constexpr
459
+ ):
460
+
461
+ """
462
+ Kernel for backward gated MLP
463
+ We group along the M axis
464
+
465
+ Ref :
466
+ w1_grad = torch.matmul(y1_grad.t(), x)
467
+ w2_grad = torch.matmul(y2_grad.t(), x)
468
+ """
469
+
470
+ pid = tl.program_id(0)
471
+
472
+ num_pid_n = tl.cdiv(N, BLOCK_N) # number of program ids along the M axis
473
+ num_pid_k = tl.cdiv(K, BLOCK_K) # number of programs ids along the K axis
474
+
475
+ num_pid_in_group = GROUP_N * num_pid_k # number of programs in group
476
+ group_id = pid // num_pid_in_group # id of the group this program is in
477
+ first_pid_n = group_id * GROUP_N # row-id of the first program in the group
478
+ GROUP_N = min(
479
+ num_pid_n - first_pid_n, GROUP_N
480
+ ) # if `num_pid_m` isn't divisible by `GROUP_M`, the last group is smaller
481
+
482
+ # *within groups*, programs are ordered in a column-major order
483
+ # row-id /col-id of the program in the *launch grid*
484
+ pid_n = first_pid_n + (pid % GROUP_N)
485
+ pid_k = (pid % num_pid_in_group) // GROUP_N
486
+
487
+ # block pointers
488
+ y1_grad_block_ptr = tl.make_block_ptr(
489
+ base=y1_grad,
490
+ shape=(N, M),
491
+ strides=(1, stride_dom),
492
+ offsets=(pid_n * BLOCK_N, 0),
493
+ block_shape=(BLOCK_N, BLOCK_M),
494
+ order=(0, 1),
495
+ )
496
+
497
+ y2_grad_block_ptr = tl.make_block_ptr(
498
+ base=y2_grad,
499
+ shape=(N, M),
500
+ strides=(1, stride_dom),
501
+ offsets=(pid_n * BLOCK_N, 0),
502
+ block_shape=(BLOCK_N, BLOCK_M),
503
+ order=(0, 1),
504
+ )
505
+
506
+ input_block_ptr = tl.make_block_ptr(
507
+ base=input,
508
+ shape=(M, K),
509
+ strides=(stride_im, 1),
510
+ offsets=(0, pid_k * BLOCK_K),
511
+ block_shape=(BLOCK_M, BLOCK_K),
512
+ order=(1, 0),
513
+ )
514
+
515
+ ref = tl.load(input + tl.arange(0, 1))
516
+
517
+ # initialize and iteratively update accumulator
518
+ acc_dw1 = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
519
+ acc_dw2 = tl.zeros((BLOCK_N, BLOCK_K), dtype=tl.float32)
520
+
521
+ for i in range(0, M, BLOCK_M):
522
+
523
+ if IS_EVEN_MNK:
524
+ y1grad_blk = tl.load(y1_grad_block_ptr)
525
+ y2grad_blk = tl.load(y2_grad_block_ptr)
526
+ x = tl.load(input_block_ptr)
527
+ else:
528
+ y1grad_blk = tl.load(y1_grad_block_ptr, boundary_check=(0, 1))
529
+ y2grad_blk = tl.load(y2_grad_block_ptr, boundary_check=(0, 1))
530
+ x = tl.load(input_block_ptr, boundary_check=(0, 1))
531
+
532
+ acc_dw1 += tl.dot(y1grad_blk, x)
533
+ acc_dw2 += tl.dot(y2grad_blk, x)
534
+
535
+ y1_grad_block_ptr = tl.advance(y1_grad_block_ptr, (0, BLOCK_M))
536
+ y2_grad_block_ptr = tl.advance(y2_grad_block_ptr, (0, BLOCK_M))
537
+ input_block_ptr = tl.advance(input_block_ptr, (BLOCK_M, 0))
538
+
539
+ # write back result
540
+ dw1_ptrs = tl.make_block_ptr(
541
+ base=dw1,
542
+ shape=(N, K),
543
+ strides=(stride_wn, 1),
544
+ offsets=(pid_n * BLOCK_N, pid_k * BLOCK_K),
545
+ block_shape=(BLOCK_N, BLOCK_K),
546
+ order=(1, 0),
547
+ )
548
+
549
+ dw2_ptrs = tl.make_block_ptr(
550
+ base=dw2,
551
+ shape=(N, K),
552
+ strides=(stride_wn, 1),
553
+ offsets=(pid_n * BLOCK_N, pid_k * BLOCK_K),
554
+ block_shape=(BLOCK_N, BLOCK_K),
555
+ order=(1, 0),
556
+ )
557
+
558
+ if IS_EVEN_MNK:
559
+ tl.store(dw1_ptrs, acc_dw1.to(dtype))
560
+ tl.store(dw2_ptrs, acc_dw2.to(dtype))
561
+ else:
562
+ tl.store(dw1_ptrs, acc_dw1.to(dtype), boundary_check=(0, 1))
563
+ tl.store(dw2_ptrs, acc_dw2.to(dtype), boundary_check=(0, 1))
564
+
565
+
566
+ class GatedMLP(torch.autograd.Function):
567
+ @staticmethod
568
+ @custom_fwd
569
+ def forward(ctx, x, w1, w2, use_gelu=True):
570
+
571
+ BLOCK_M = 128
572
+ BLOCK_N = 64
573
+ BLOCK_K = 64
574
+ GROUP_M = 8
575
+
576
+ SAVE_ACT_IN = x.requires_grad
577
+
578
+ if torch.is_autocast_enabled():
579
+ x = x.to(torch.get_autocast_gpu_dtype())
580
+ w1 = w1.to(torch.get_autocast_gpu_dtype())
581
+ w2 = w2.to(torch.get_autocast_gpu_dtype())
582
+
583
+ assert x.is_contiguous()
584
+ assert w1.is_contiguous()
585
+ assert w2.is_contiguous()
586
+ assert w1.shape == w2.shape
587
+ assert x.shape[2] == w1.shape[1]
588
+ assert x.shape[2] == w2.shape[1]
589
+
590
+ x_ = x if x.ndim == 2 else x.flatten(0, -2)
591
+
592
+ M, K = x_.shape
593
+ N, K = w1.shape
594
+
595
+ IS_EVEN_MNK = ((M % BLOCK_M) == 0) and ((N % BLOCK_N) == 0) and ((K % BLOCK_K) == 0)
596
+
597
+ out = torch.empty((M, N), device=x.device, dtype=x.dtype)
598
+
599
+ tl_dtype = to_tl_dtype(x.dtype)
600
+
601
+ act_input_1, act_input_2 = None, None
602
+ if SAVE_ACT_IN:
603
+ act_input_1 = torch.empty_like(out)
604
+ act_input_2 = torch.empty_like(out)
605
+
606
+ grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
607
+ gated_matmul_fwd[grid](
608
+ out,
609
+ x_, w1, w2,
610
+ act_input_1, act_input_2,
611
+ M, N, K,
612
+ out.stride(0), x_.stride(0),
613
+ w1.stride(0),
614
+ tl_dtype,
615
+ BLOCK_M, GROUP_M, BLOCK_N, BLOCK_K,
616
+ use_gelu,
617
+ SAVE_ACT_IN,
618
+ IS_EVEN_MNK,
619
+ )
620
+
621
+ ctx.save_for_backward(x_, w1, w2, act_input_1, act_input_2)
622
+ ctx.use_gelu = use_gelu
623
+ ctx.is_even_nmk = IS_EVEN_MNK
624
+ ctx.x_shape = x.shape
625
+
626
+ out = out if x.ndim == 2 else out.reshape(*x.shape[:-1], N)
627
+
628
+ return out
629
+
630
+ @staticmethod
631
+ @custom_bwd
632
+ def backward(ctx, dout):
633
+ BLOCK_M = 64
634
+ BLOCK_N = 64
635
+ BLOCK_K = 64
636
+ GROUP_M = 8
637
+
638
+ x_, w1, w2, act_input_1, act_input_2 = ctx.saved_tensors
639
+
640
+ M, K = x_.shape
641
+ N, K = w1.shape
642
+
643
+ tl_dtype = to_tl_dtype(x_.dtype)
644
+
645
+ '''
646
+ din = torch.empty_like(x_)
647
+ dw1 = torch.empty_like(w1)
648
+ dw2 = torch.empty_like(w2)
649
+
650
+ dout_ = dout if dout.ndim == 2 else dout.flatten(0, -2)
651
+
652
+ y1_grad = torch.empty_like(dout_)
653
+ y2_grad = torch.empty_like(dout_)
654
+
655
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
656
+ gated_matmul_bwd_ygrad[grid](
657
+ dout_,
658
+ y1_grad, y2_grad,
659
+ act_input_1, act_input_2,
660
+ M, N,
661
+ dout_.stride(0),
662
+ # Meta-parameters
663
+ tl_dtype,
664
+ BLOCK_M, BLOCK_N,
665
+ ctx.use_gelu,
666
+ ctx.is_even_nmk)
667
+
668
+ grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(K, BLOCK_K),)
669
+ gated_matmul_bwd_input[grid](
670
+ w1, w2,
671
+ y1_grad, y2_grad,
672
+ din,
673
+ M, N, K,
674
+ dout_.stride(0), x_.stride(0),
675
+ w1.stride(0),
676
+ tl_dtype,
677
+ BLOCK_M, GROUP_M,
678
+ BLOCK_N, BLOCK_K,
679
+ ctx.is_even_nmk)
680
+
681
+ # reorder sizes
682
+ BLOCK_M = 64
683
+ BLOCK_N = 64
684
+ grid = (triton.cdiv(N, BLOCK_N) * triton.cdiv(K, BLOCK_K),)
685
+ gated_matmul_bwd_weights[grid](
686
+ x_,
687
+ y1_grad, y2_grad,
688
+ dw1, dw2,
689
+ M, N, K,
690
+ y1_grad.stride(0), x_.stride(0),
691
+ dw1.stride(0),
692
+ tl_dtype,
693
+ BLOCK_M, GROUP_M,
694
+ BLOCK_N, BLOCK_K,
695
+ ctx.is_even_nmk)
696
+
697
+ din = din if len(ctx.x_shape) == 2 else din.reshape(ctx.x_shape)
698
+ '''
699
+
700
+ dout_ = dout if dout.ndim == 2 else dout.flatten(0, -2)
701
+
702
+ y1_grad = torch.empty_like(dout_)
703
+ y2_grad = torch.empty_like(dout_)
704
+
705
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
706
+ gated_matmul_bwd_ygrad[grid](
707
+ dout_,
708
+ y1_grad, y2_grad,
709
+ act_input_1, act_input_2,
710
+ M, N,
711
+ dout_.stride(0),
712
+ # Meta-parameters
713
+ tl_dtype,
714
+ BLOCK_M, BLOCK_N,
715
+ ctx.use_gelu,
716
+ ctx.is_even_nmk)
717
+
718
+ #y2_grad = torch.mul(gelu_torch(x_ @ w1.t()), dout_)
719
+ #y1_grad = torch.mul(gelu_grad_torch(x_ @ w1.t()) * (x_ @ w2.t()), dout_)
720
+
721
+ din = torch.matmul(y2_grad, w2) + torch.matmul(y1_grad, w1)
722
+ dw1 = torch.matmul(y1_grad.t(), x_)
723
+ dw2 = torch.matmul(y2_grad.t(), x_)
724
+
725
+ din = din if len(ctx.x_shape) == 2 else din.reshape(ctx.x_shape)
726
+
727
+ return din, dw1, dw2, None
728
+
729
+ gated_mlp = GatedMLP.apply
modeling_flash_t5.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import copy
6
+ import math
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ import torch.nn.functional as F
13
+
14
+ from transformers.modeling_utils import ModuleUtilsMixin
15
+ from transformers.modeling_outputs import ModelOutput, Seq2SeqModelOutput, BaseModelOutput
16
+ from transformers import PreTrainedModel
17
+
18
+ try:
19
+ from .rms_norm import fast_rms_layernorm
20
+ except ImportError:
21
+ fast_rms_layernorm = None
22
+
23
+ try:
24
+ from .cross_entropy_loss import fast_cross_entropy_loss
25
+ except ImportError:
26
+ fast_cross_entropy_loss = None
27
+
28
+ try:
29
+ from .flash_attention_v2_bias import attention as flash_attention_triton
30
+ except ImportError:
31
+ fast_cross_entropy_loss = None
32
+
33
+ try:
34
+ from .gated_mlp import gated_mlp
35
+ except ImportError:
36
+ gated_mlp = None
37
+
38
+ try:
39
+ #from flash_attn import flash_attn_kvpacked_func, flash_attn_func
40
+ from .fa2_compilable import flash_attn_kvpacked_func, flash_attn_func
41
+ except ImportError:
42
+ flash_attn_kvpacked_func, flash_attn_func = None, None
43
+
44
+ from .attn_ref import attn_ref
45
+
46
+ from .configuration_flash_t5 import FlashT5Config
47
+ from .positional_encoding import ALiBiPositionalEncoding, RelativePositionalEncoding, RotaryPositionalEncoding
48
+
49
+ @dataclass
50
+ class EncoderOutput(ModelOutput):
51
+ hidden_states: torch.FloatTensor = None
52
+ attention_mask: torch.FloatTensor = None
53
+
54
+ @dataclass
55
+ class Seq2SeqLMOutput(ModelOutput):
56
+ loss: torch.FloatTensor = None
57
+ logits: torch.FloatTensor = None
58
+ encoder_outputs: EncoderOutput = None
59
+
60
+
61
+ class FlashT5CrossEntropyLoss(nn.Module):
62
+ def __init__(self, z_loss_factor=0.0, label_smoothing=0.0, use_triton_crossentropy=False):
63
+
64
+ super().__init__()
65
+
66
+ if use_triton_crossentropy and fast_cross_entropy_loss is None:
67
+ raise ImportError("fast_cross_entropy_loss is not available")
68
+
69
+ self.use_triton_crossentropy = use_triton_crossentropy
70
+ self.z_loss_factor = z_loss_factor
71
+
72
+ self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
73
+
74
+ def compute_zloss(self, logits: torch.Tensor, z_loss: float):
75
+ logits_sum = torch.logsumexp(logits, dim=-1, keepdim=True)
76
+ log_z = torch.squeeze(logits_sum, axis=-1)
77
+ total_z_loss = z_loss * torch.square(log_z)
78
+ return total_z_loss.mean()
79
+
80
+ def forward(self, logits, labels):
81
+
82
+ if self.use_triton_crossentropy:
83
+ return fast_cross_entropy_loss(logits, labels, z_loss_factor=self.z_loss_factor)
84
+
85
+ # use standard method
86
+ batch, seq_len, d = logits.shape
87
+ logits_flatten = logits.float().view(batch*seq_len, d) # Must cast to float32 for numerical stability
88
+ labels_flatten = labels.view(-1)
89
+ loss = self.cross_entropy_loss(logits_flatten, labels_flatten)
90
+ z_loss = 0.0
91
+ if self.z_loss_factor != 0.0:
92
+ z_loss = self.compute_zloss(logits_flatten[labels_flatten != -100],
93
+ z_loss=self.z_loss_factor)
94
+ return loss, z_loss
95
+
96
+ class FlashT5LayerNorm(nn.Module):
97
+ def __init__(self, hidden_size, eps=1e-6, use_triton_layernorm=False):
98
+ """
99
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
100
+ """
101
+ super().__init__()
102
+
103
+ if use_triton_layernorm and fast_rms_layernorm is None:
104
+ raise ImportError("fast_rms_layernorm is not available")
105
+
106
+ self.use_triton_layernorm = use_triton_layernorm
107
+ self.weight = nn.Parameter(torch.ones(hidden_size))
108
+ self.variance_epsilon = eps
109
+
110
+ def forward(self, hidden_states):
111
+
112
+ if self.use_triton_layernorm:
113
+ return fast_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
114
+
115
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
116
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
117
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
118
+ # half-precision inputs is done in fp32
119
+
120
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
121
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
122
+
123
+ # convert into half-precision if necessary
124
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
125
+ hidden_states = hidden_states.to(self.weight.dtype)
126
+
127
+ return self.weight * hidden_states
128
+
129
+ class FlashT5DenseAct(nn.Module):
130
+ def __init__(self, config: FlashT5Config):
131
+ super().__init__()
132
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
133
+ self.dropout = nn.Dropout(config.dropout_rate)
134
+ self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
135
+
136
+ def forward(self, hidden_states):
137
+ hidden_states = self.wi(hidden_states)
138
+ hidden_states = self.act(hidden_states)
139
+ hidden_states = self.dropout(hidden_states)
140
+ if (
141
+ isinstance(self.wo.weight, torch.Tensor)
142
+ and hidden_states.dtype != self.wo.weight.dtype
143
+ and self.wo.weight.dtype != torch.int8
144
+ ):
145
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
146
+
147
+ return hidden_states
148
+
149
+ class FlashT5DenseGatedAct(nn.Module):
150
+ def __init__(self, config: FlashT5Config):
151
+ super().__init__()
152
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
153
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
154
+ self.dropout = nn.Dropout(config.dropout_rate)
155
+ self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU()
156
+
157
+ self.use_triton_gated_mlp = config.use_triton_gated_mlp
158
+ if self.use_triton_gated_mlp and gated_mlp is None:
159
+ raise ImportError("gated_mlp is not available")
160
+ self.use_gelu_act = config.use_gelu_act
161
+
162
+ def forward(self, hidden_states):
163
+
164
+ if self.use_triton_gated_mlp:
165
+ return gated_mlp(hidden_states, self.wi_0.weight, self.wi_1.weight, self.use_gelu_act)
166
+
167
+ hidden_act = self.act(self.wi_0(hidden_states))
168
+ hidden_linear = self.wi_1(hidden_states)
169
+ hidden_states = hidden_act * hidden_linear
170
+ hidden_states = self.dropout(hidden_states)
171
+
172
+ return hidden_states
173
+
174
+ class FlashT5LayerFF(nn.Module):
175
+ def __init__(self, config: FlashT5Config):
176
+ super().__init__()
177
+ if config.use_glu_mlp:
178
+ self.act = FlashT5DenseGatedAct(config)
179
+ else:
180
+ self.act = FlashT5DenseAct(config)
181
+
182
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
183
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
184
+ self.dropout = nn.Dropout(config.dropout_rate)
185
+
186
+ def forward(self, hidden_states):
187
+ forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states)
188
+ forwarded_states = self.act(forwarded_states)
189
+ forwarded_states = self.wo(forwarded_states)
190
+ hidden_states = hidden_states + self.dropout(forwarded_states)
191
+ return hidden_states
192
+
193
+
194
+ class FlashT5Attention(nn.Module, ModuleUtilsMixin):
195
+ def __init__(self, config: FlashT5Config, has_positional_encoding=False, is_causal=False):
196
+ super().__init__()
197
+ self.is_decoder = config.is_decoder
198
+ self.has_positional_encoding = has_positional_encoding
199
+ self.is_causal = is_causal
200
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
201
+ self.relative_attention_max_distance = config.relative_attention_max_distance
202
+ self.d_model = config.d_model
203
+ self.key_value_proj_dim = config.d_kv
204
+ self.n_heads = config.num_heads
205
+ self.p_dropout = config.attention_dropout_rate
206
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
207
+ self.use_flash_attention = config.use_flash_attention
208
+ self.position_encoding_type = config.position_encoding_type
209
+ self.max_sequence_length = config.max_sequence_length
210
+ self.softmax_scale = 1.0/math.sqrt(self.n_heads)
211
+ self.use_full_bias_size = config.use_full_bias_size
212
+
213
+ if self.use_flash_attention == "triton" and flash_attention_triton is None:
214
+ raise ImportError("flash_attention_triton is not available")
215
+ elif self.use_flash_attention == "fa2" and flash_attn_func is None:
216
+ raise ImportError("Flash Attention 2 is not available")
217
+
218
+ assert (self.p_dropout == 0.0) or (self.use_flash_attention != "triton"), "Triton attention does not support dropout"
219
+
220
+ self.pe_encoding = None
221
+ if self.position_encoding_type == "ALiBi" and has_positional_encoding:
222
+ # build alibi matrix with an upper bound on seq length
223
+ self.pe_encoding = ALiBiPositionalEncoding(self.max_sequence_length, self.n_heads, config.alibi_mode, config.use_randomized_position_encoding)
224
+ elif self.position_encoding_type == "t5" and has_positional_encoding:
225
+ self.pe_encoding = RelativePositionalEncoding(self.relative_attention_num_buckets, self.relative_attention_max_distance, self.n_heads, self.max_sequence_length, config.use_randomized_position_encoding)
226
+ elif self.position_encoding_type == "RoPE":
227
+ self.pe_encoding = RotaryPositionalEncoding(int(self.key_value_proj_dim * config.rotary_emb_fraction), self.max_sequence_length, config.rotary_base, config.rotary_interleaved, config.rotary_scale_base, config.use_randomized_position_encoding)
228
+
229
+ self.Wq = nn.Linear(self.d_model, self.inner_dim, bias=False)
230
+ self.Wk = nn.Linear(self.d_model, self.inner_dim, bias=False)
231
+ self.Wv = nn.Linear(self.d_model, self.inner_dim, bias=False)
232
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states,
237
+ mask=None,
238
+ key_value_states=None,
239
+ position_bias=None,
240
+ ):
241
+ """
242
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
243
+ """
244
+ # Input is (batch_size, seq_length, dim)
245
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
246
+ batch_size, seq_length = hidden_states.shape[:2]
247
+ key_length = seq_length if key_value_states is None else key_value_states.shape[1]
248
+ q = self.Wq(hidden_states)
249
+ if key_value_states is None:
250
+ k = self.Wk(hidden_states)
251
+ v = self.Wv(hidden_states)
252
+ else:
253
+ k = self.Wk(key_value_states)
254
+ v = self.Wv(key_value_states)
255
+
256
+ q = q.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim)
257
+ k = k.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
258
+ v = v.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim)
259
+
260
+ if position_bias is None and self.pe_encoding is not None:
261
+ q, k, v, position_bias = self.pe_encoding(q, k, v)
262
+
263
+ if position_bias is not None and self.use_full_bias_size and (self.use_flash_attention == "fa2" or self.use_flash_attention == "triton"):
264
+ position_bias = position_bias.expand(q.shape[0], q.shape[2], q.shape[1], k.shape[1]).contiguous()
265
+
266
+ if self.use_flash_attention == "fa2":
267
+ output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, attn_bias=position_bias, causal=self.is_causal)
268
+ elif self.use_flash_attention == "triton":
269
+ q = q.permute(0, 2, 1, 3)
270
+ k = k.permute(0, 2, 1, 3)
271
+ v = v.permute(0, 2, 1, 3)
272
+ output = flash_attention_triton(q, k, v, position_bias, self.is_causal, self.softmax_scale)
273
+ output = output.permute(0, 2, 1, 3)
274
+ else: # use flash attention
275
+ q = q.permute(0, 2, 1, 3)
276
+ k = k.permute(0, 2, 1, 3)
277
+ v = v.permute(0, 2, 1, 3)
278
+ output = attn_ref(q, k, v, position_bias, dropout_p=self.p_dropout, sm_scale=self.softmax_scale, causal=self.is_causal)
279
+ output = output.permute(0, 2, 1, 3)
280
+
281
+ output = self.o(output.reshape(output.shape[0], output.shape[1], self.inner_dim))
282
+ return (output, position_bias)
283
+
284
+
285
+ class FlashT5LayerSelfAttention(nn.Module):
286
+ def __init__(self, config, has_positional_encoding=False):
287
+ super().__init__()
288
+ self.self_attention = FlashT5Attention(config, has_positional_encoding=has_positional_encoding, is_causal=config.is_decoder)
289
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
290
+ self.dropout = nn.Dropout(config.dropout_rate)
291
+
292
+ def forward(
293
+ self,
294
+ hidden_states,
295
+ attention_mask=None,
296
+ position_bias=None,
297
+ ):
298
+ normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states)
299
+ attention_output = self.self_attention(
300
+ normed_hidden_states,
301
+ mask=attention_mask,
302
+ position_bias=position_bias,
303
+ )
304
+ hidden_states = hidden_states + self.dropout(attention_output[0])
305
+ outputs = (hidden_states,) + attention_output[1:]
306
+ return outputs
307
+
308
+
309
+ class FlashT5LayerCrossAttention(nn.Module):
310
+ def __init__(self, config):
311
+ super().__init__()
312
+ self.cross_attention = FlashT5Attention(config, has_positional_encoding=False)
313
+ self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
314
+ self.dropout = nn.Dropout(config.dropout_rate)
315
+
316
+ def forward(
317
+ self,
318
+ hidden_states,
319
+ key_value_states,
320
+ attention_mask=None,
321
+ position_bias=None,
322
+ ):
323
+ normed_hidden_states = self.layer_norm(hidden_states)
324
+ attention_output = self.cross_attention(
325
+ normed_hidden_states,
326
+ mask=attention_mask,
327
+ key_value_states=key_value_states,
328
+ position_bias=position_bias,
329
+ )
330
+ layer_output = hidden_states + self.dropout(attention_output[0])
331
+ outputs = (layer_output,) + attention_output[1:]
332
+ return outputs
333
+
334
+
335
+ class FlashT5Block(nn.Module):
336
+ def __init__(self, config, has_positional_encoding=False):
337
+ super().__init__()
338
+ self.is_decoder = config.is_decoder
339
+
340
+ self.self_attention_layer = FlashT5LayerSelfAttention(config, has_positional_encoding=has_positional_encoding)
341
+
342
+ if self.is_decoder:
343
+ self.cross_attention_layer = FlashT5LayerCrossAttention(config)
344
+
345
+ self.ff_layer = FlashT5LayerFF(config)
346
+
347
+ def forward(
348
+ self,
349
+ hidden_states,
350
+ attention_mask=None,
351
+ position_bias=None,
352
+ encoder_hidden_states=None,
353
+ encoder_attention_mask=None,
354
+ encoder_decoder_position_bias=None,
355
+ ):
356
+ self_attention_outputs = self.self_attention_layer(
357
+ hidden_states,
358
+ attention_mask=attention_mask,
359
+ position_bias=position_bias,
360
+ )
361
+ hidden_states = self_attention_outputs[0]
362
+ attention_outputs = self_attention_outputs[1:] # Relative position weights
363
+
364
+ if self.is_decoder and encoder_hidden_states is not None:
365
+ cross_attention_outputs = self.cross_attention_layer(
366
+ hidden_states,
367
+ key_value_states=encoder_hidden_states,
368
+ attention_mask=encoder_attention_mask,
369
+ position_bias=encoder_decoder_position_bias,
370
+ )
371
+ hidden_states = cross_attention_outputs[0]
372
+
373
+ # Keep relative position weights
374
+ attention_outputs = attention_outputs + cross_attention_outputs[1:]
375
+
376
+ # Apply Feed Forward layer
377
+ hidden_states = self.ff_layer(hidden_states)
378
+
379
+ outputs = (hidden_states,) + attention_outputs
380
+ return outputs # hidden-states, (self-attention position bias), (cross-attention position bias)
381
+
382
+ class FlashT5Stack(nn.Module, ModuleUtilsMixin):
383
+ def __init__(self, config, embed_tokens):
384
+ super().__init__()
385
+ assert embed_tokens is not None
386
+
387
+ self.config = config
388
+ self.embed_tokens = embed_tokens
389
+ self.is_decoder = config.is_decoder
390
+ self.use_flash_attention = config.use_flash_attention
391
+
392
+ self.block = nn.ModuleList(
393
+ [FlashT5Block(config, has_positional_encoding=bool(i == 0)) for i in range(config.num_layers)]
394
+ )
395
+
396
+ self.final_layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm)
397
+ self.dropout = nn.Dropout(config.dropout_rate)
398
+
399
+ def forward(
400
+ self,
401
+ input_ids=None,
402
+ attention_mask=None,
403
+ encoder_hidden_states=None,
404
+ encoder_attention_mask=None,
405
+ inputs_embeds=None,
406
+ head_mask=None,
407
+ cross_attn_head_mask=None,
408
+ past_key_values=None,
409
+ use_cache=None,
410
+ output_attentions=None,
411
+ output_hidden_states=None,
412
+ return_dict=None) -> BaseModelOutput:
413
+ input_shape = input_ids.size()
414
+ batch_size, seq_length = input_shape
415
+
416
+ if inputs_embeds is None:
417
+ inputs_embeds = self.embed_tokens(input_ids)
418
+
419
+ if torch.is_autocast_enabled() and input_ids.device.type == 'cuda':
420
+ inputs_embeds = inputs_embeds.to(torch.get_autocast_gpu_dtype())
421
+
422
+ # Masking
423
+ if attention_mask is None:
424
+ attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=torch.bool)
425
+
426
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
427
+ encoder_seq_length = encoder_hidden_states.shape[1]
428
+ encoder_attention_mask = torch.ones(
429
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool
430
+ )
431
+
432
+ position_bias = None
433
+ encoder_decoder_position_bias = None
434
+
435
+ hidden_states = self.dropout(inputs_embeds)
436
+
437
+ for _, layer_module in enumerate(self.block):
438
+ layer_outputs = layer_module(
439
+ hidden_states,
440
+ attention_mask=attention_mask,
441
+ position_bias=position_bias,
442
+ encoder_hidden_states=encoder_hidden_states,
443
+ encoder_attention_mask=encoder_attention_mask,
444
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
445
+ )
446
+
447
+ # We share the position biases between the layers - the first layer store them
448
+ position_bias = layer_outputs[1]
449
+ if self.is_decoder and encoder_hidden_states is not None:
450
+ encoder_decoder_position_bias = layer_outputs[2]
451
+
452
+ hidden_states = layer_outputs[0]
453
+
454
+ hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states)
455
+ hidden_states = self.dropout(hidden_states)
456
+
457
+ return BaseModelOutput(
458
+ last_hidden_state=hidden_states
459
+ )
460
+
461
+
462
+ class FlashT5PreTrainedModel(PreTrainedModel):
463
+ """
464
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
465
+ models.
466
+ """
467
+
468
+ config_class = FlashT5Config
469
+ base_model_prefix = "transformer"
470
+ is_parallelizable = False
471
+ supports_gradient_checkpointing = True
472
+ _no_split_modules = ["FlashT5Block"]
473
+ _keep_in_fp32_modules = []
474
+
475
+ def _init_weights(self, module):
476
+ factor = self.config.initializer_factor # Used for testing weights initialization
477
+ if isinstance(module, FlashT5LayerNorm):
478
+ module.weight.data.fill_(factor * 1.0)
479
+ elif isinstance(module, (FlashT5ForConditionalGeneration)):
480
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
481
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
482
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * self.config.d_model ** -0.5)
483
+ elif isinstance(module, FlashT5DenseGatedAct):
484
+ d_ff, d_model = module.wi_0.weight.data.size()
485
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
486
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
487
+ elif isinstance(module, FlashT5LayerFF):
488
+ d_ff, d_model = module.wo.weight.data.size()
489
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
490
+ elif isinstance(module, FlashT5Attention):
491
+ d_model = self.config.d_model
492
+ key_value_proj_dim = self.config.d_kv
493
+ n_heads = self.config.num_heads
494
+ module.Wq.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
495
+ module.Wk.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
496
+ module.Wv.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
497
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
498
+ if module.has_positional_encoding:
499
+ if hasattr(module.pe_encoding, "relative_attention_bias"):
500
+ module.pe_encoding.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
501
+
502
+ def _shift_right(self, input_ids):
503
+ decoder_start_token_id = self.config.decoder_start_token_id
504
+ pad_token_id = self.config.pad_token_id
505
+
506
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
507
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
508
+ shifted_input_ids[..., 0] = decoder_start_token_id
509
+
510
+ # replace possible -100 values in labels by `pad_token_id`
511
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
512
+
513
+ return shifted_input_ids
514
+
515
+
516
+ class FlashT5Model(FlashT5PreTrainedModel):
517
+ def __init__(self, config: FlashT5Config):
518
+ super().__init__(config)
519
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
520
+
521
+ encoder_config = copy.deepcopy(config)
522
+ encoder_config.is_decoder = False
523
+ encoder_config.use_cache = False
524
+ encoder_config.is_encoder_decoder = False
525
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
526
+
527
+ decoder_config = copy.deepcopy(config)
528
+ decoder_config.is_decoder = True
529
+ decoder_config.is_encoder_decoder = False
530
+ decoder_config.num_layers = config.num_decoder_layers
531
+ self.decoder = FlashT5Stack(decoder_config, self.shared)
532
+
533
+ # Initialize weights and apply final processing
534
+ self.post_init()
535
+
536
+ # Model parallel
537
+ self.model_parallel = False
538
+ self.device_map = None
539
+
540
+ def get_input_embeddings(self):
541
+ return self.shared
542
+
543
+ def set_input_embeddings(self, new_embeddings):
544
+ self.shared = new_embeddings
545
+ self.encoder.set_input_embeddings(new_embeddings)
546
+ self.decoder.set_input_embeddings(new_embeddings)
547
+
548
+ def get_encoder(self):
549
+ return self.encoder
550
+
551
+ def get_decoder(self):
552
+ return self.decoder
553
+
554
+ def forward(
555
+ self,
556
+ input_ids: Optional[torch.LongTensor] = None,
557
+ attention_mask: Optional[torch.FloatTensor] = None,
558
+ decoder_input_ids: Optional[torch.LongTensor] = None,
559
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
560
+ head_mask: Optional[torch.FloatTensor] = None,
561
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
562
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
563
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
564
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
565
+ inputs_embeds: Optional[torch.Tensor] = None,
566
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
567
+ use_cache: Optional[bool] = None,
568
+ output_attentions: Optional[bool] = None,
569
+ output_hidden_states: Optional[bool] = None,
570
+ return_dict: Optional[bool] = None,
571
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
572
+
573
+ # Encode if needed (training, first prediction pass)
574
+ if encoder_outputs is None:
575
+ encoder_outputs = self.encoder(
576
+ input_ids=input_ids,
577
+ attention_mask=attention_mask,
578
+ inputs_embeds=inputs_embeds
579
+ )
580
+
581
+ hidden_states = encoder_outputs[0]
582
+
583
+ # Decode
584
+ decoder_outputs = self.decoder(
585
+ input_ids=decoder_input_ids,
586
+ attention_mask=decoder_attention_mask,
587
+ inputs_embeds=decoder_inputs_embeds,
588
+ encoder_hidden_states=hidden_states,
589
+ encoder_attention_mask=attention_mask
590
+ )
591
+
592
+ return Seq2SeqModelOutput(
593
+ last_hidden_state=decoder_outputs.last_hidden_state,
594
+ decoder_hidden_states=decoder_outputs.hidden_states,
595
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
596
+ encoder_hidden_states=encoder_outputs.hidden_states,
597
+ )
598
+
599
+ class FlashT5ForConditionalGeneration(FlashT5PreTrainedModel):
600
+
601
+ def __init__(self, config: FlashT5Config):
602
+ super().__init__(config)
603
+ config.is_encoder_decoder = False
604
+ assert not config.tie_word_embeddings
605
+
606
+ self.config = config
607
+ self.model_dim = config.d_model
608
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
609
+
610
+ encoder_config = copy.deepcopy(config)
611
+ encoder_config.is_decoder = False
612
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
613
+
614
+ decoder_config = copy.deepcopy(config)
615
+ decoder_config.is_decoder = True
616
+ decoder_config.num_layers = config.num_decoder_layers
617
+ self.decoder = FlashT5Stack(decoder_config, self.shared)
618
+
619
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
620
+
621
+ self.loss_fct = FlashT5CrossEntropyLoss(z_loss_factor=config.z_loss,
622
+ label_smoothing=config.label_smoothing,
623
+ use_triton_crossentropy=config.use_triton_crossentropy)
624
+
625
+ # Initialize weights and apply final processing
626
+ self.post_init()
627
+
628
+ def prepare_inputs_for_generation(
629
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
630
+ ):
631
+ # do nothing
632
+ model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
633
+
634
+ return model_inputs
635
+
636
+ def get_input_embeddings(self):
637
+ return self.shared
638
+
639
+ def set_input_embeddings(self, value):
640
+ self.shared = value
641
+
642
+ def generate(
643
+ self,
644
+ input_ids: Optional[torch.LongTensor] = None,
645
+ attention_mask: Optional[torch.FloatTensor] = None,
646
+ max_length = 32,
647
+ **kwargs,
648
+ ) -> torch.LongTensor:
649
+ """
650
+ input_ids: B x L_encoder, int64
651
+ attention_mask: B x L_encoder, int64
652
+ 1 for tokens to attend to, 0 for tokens to ignore
653
+
654
+ Generation:
655
+ Starts with 0, ends with 1, padding is 0
656
+
657
+ # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s
658
+ """
659
+ B, _ = input_ids.size()
660
+ labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device)
661
+ encoder_outputs = None
662
+
663
+ for _ in range(max_length):
664
+ out = self.forward(
665
+ input_ids=input_ids,
666
+ attention_mask=attention_mask,
667
+ decoder_input_ids=labels,
668
+ encoder_outputs=encoder_outputs,
669
+ )
670
+ encoder_outputs = out.encoder_outputs
671
+ top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1)
672
+ labels = torch.cat([labels, top_labels], dim=-1)
673
+
674
+ if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B:
675
+ break
676
+
677
+ labels[:, -1] = 1
678
+
679
+ # Mask out the padding, i.e., all positions after the first 1 with 0
680
+ B, L = labels.size()
681
+ mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1)
682
+ labels = labels.masked_fill(~mask, 0)
683
+
684
+ return labels
685
+
686
+ def forward(
687
+ self,
688
+ input_ids: Optional[torch.LongTensor] = None,
689
+ attention_mask: Optional[torch.FloatTensor] = None,
690
+ decoder_input_ids: Optional[torch.LongTensor] = None,
691
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
692
+ labels: Optional[torch.LongTensor] = None,
693
+ encoder_outputs = None,
694
+ ) -> Seq2SeqLMOutput:
695
+ """
696
+ input_ids: B x L_encoder, int64
697
+ attention_mask: B x L_encoder, int64
698
+ 1 for tokens to attend to, 0 for tokens to ignore
699
+ labels: B x L_decoder, int64
700
+ """
701
+ if encoder_outputs is None:
702
+ encoder_outputs = self.encoder(
703
+ input_ids=input_ids,
704
+ attention_mask=attention_mask,
705
+ )
706
+
707
+ hidden_states = encoder_outputs.hidden_states
708
+
709
+ if labels is not None and decoder_input_ids is None:
710
+ decoder_input_ids = self._shift_right(labels)
711
+
712
+ decoder_outputs = self.decoder(
713
+ input_ids=decoder_input_ids,
714
+ attention_mask=decoder_attention_mask,
715
+ encoder_hidden_states=hidden_states,
716
+ encoder_attention_mask=attention_mask,
717
+ )
718
+
719
+ sequence_output = decoder_outputs[0]
720
+ lm_logits = self.lm_head(sequence_output)
721
+
722
+ loss = None
723
+ if labels is not None:
724
+ loss, z_loss = self.loss_fct(lm_logits, labels)
725
+ loss += z_loss
726
+
727
+ return Seq2SeqLMOutput(
728
+ loss=loss,
729
+ logits=lm_logits,
730
+ encoder_outputs=encoder_outputs,
731
+ )
732
+
733
+
734
+
735
+ class FlashT5EncoderModel(FlashT5PreTrainedModel):
736
+ _tied_weights_keys = ["encoder.embed_tokens.weight"]
737
+
738
+ def __init__(self, config: FlashT5Config):
739
+ super().__init__(config)
740
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
741
+
742
+ encoder_config = copy.deepcopy(config)
743
+ encoder_config.use_cache = False
744
+ encoder_config.is_encoder_decoder = False
745
+ self.encoder = FlashT5Stack(encoder_config, self.shared)
746
+
747
+ # Initialize weights and apply final processing
748
+ self.post_init()
749
+
750
+ # Model parallel
751
+ self.model_parallel = False
752
+ self.device_map = None
753
+
754
+
755
+ def parallelize(self, device_map=None):
756
+ warnings.warn(
757
+ "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
758
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
759
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
760
+ " 'block.1': 1, ...}",
761
+ FutureWarning,
762
+ )
763
+ self.device_map = (
764
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
765
+ if device_map is None
766
+ else device_map
767
+ )
768
+ assert_device_map(self.device_map, len(self.encoder.block))
769
+ self.encoder.parallelize(self.device_map)
770
+ self.model_parallel = True
771
+
772
+ def deparallelize(self):
773
+ warnings.warn(
774
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
775
+ FutureWarning,
776
+ )
777
+ self.encoder.deparallelize()
778
+ self.encoder = self.encoder.to("cpu")
779
+ self.model_parallel = False
780
+ self.device_map = None
781
+ torch.cuda.empty_cache()
782
+
783
+ def get_input_embeddings(self):
784
+ return self.shared
785
+
786
+ def set_input_embeddings(self, new_embeddings):
787
+ self.shared = new_embeddings
788
+ self.encoder.set_input_embeddings(new_embeddings)
789
+
790
+ def get_encoder(self):
791
+ return self.encoder
792
+
793
+ def _prune_heads(self, heads_to_prune):
794
+ """
795
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
796
+ class PreTrainedModel
797
+ """
798
+ for layer, heads in heads_to_prune.items():
799
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
800
+
801
+ def forward(
802
+ self,
803
+ input_ids: Optional[torch.LongTensor] = None,
804
+ attention_mask: Optional[torch.FloatTensor] = None,
805
+ head_mask: Optional[torch.FloatTensor] = None,
806
+ inputs_embeds: Optional[torch.FloatTensor] = None,
807
+ output_attentions: Optional[bool] = None,
808
+ output_hidden_states: Optional[bool] = None,
809
+ return_dict: Optional[bool] = None,
810
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
811
+ r"""
812
+ Returns:
813
+
814
+ Example:
815
+
816
+ ```python
817
+ >>> from transformers import AutoTokenizer, T5EncoderModel
818
+
819
+ >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
820
+ >>> model = T5EncoderModel.from_pretrained("t5-small")
821
+ >>> input_ids = tokenizer(
822
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
823
+ ... ).input_ids # Batch size 1
824
+ >>> outputs = model(input_ids=input_ids)
825
+ >>> last_hidden_states = outputs.last_hidden_state
826
+ ```"""
827
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
828
+
829
+ encoder_outputs = self.encoder(
830
+ input_ids=input_ids,
831
+ attention_mask=attention_mask,
832
+ inputs_embeds=inputs_embeds,
833
+ head_mask=head_mask,
834
+ output_attentions=output_attentions,
835
+ output_hidden_states=output_hidden_states,
836
+ return_dict=return_dict,
837
+ )
838
+
839
+ return encoder_outputs
positional_encoding.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange, repeat
5
+
6
+ from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_
7
+
8
+ class RelativePositionalEncoding(nn.Module):
9
+
10
+ def __init__(self, relative_attention_num_buckets, relative_attention_max_distance, n_heads, max_sequence_length, bidirectional=True, randomized_position=False):
11
+
12
+ super().__init__()
13
+
14
+ self.relative_attention_num_buckets = relative_attention_num_buckets
15
+ self.relative_attention_max_distance = relative_attention_max_distance
16
+ self.n_heads = n_heads
17
+ self.max_sequence_length = max_sequence_length
18
+ self.bidirectional = bidirectional
19
+ self.randomized_position = randomized_position
20
+
21
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
22
+
23
+ @staticmethod
24
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
25
+ """
26
+ Adapted from Mesh Tensorflow:
27
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
28
+
29
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
30
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
31
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
32
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
33
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
34
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
35
+
36
+ Args:
37
+ relative_position: an int32 Tensor
38
+ bidirectional: a boolean - whether the attention is bidirectional
39
+ num_buckets: an integer
40
+ max_distance: an integer
41
+
42
+ Returns:
43
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
44
+ """
45
+ relative_buckets = 0
46
+ if bidirectional:
47
+ num_buckets //= 2
48
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
49
+ relative_position = torch.abs(relative_position)
50
+ else:
51
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
52
+ # now relative_position is in the range [0, inf)
53
+
54
+ # half of the buckets are for exact increments in positions
55
+ max_exact = num_buckets // 2
56
+ is_small = relative_position < max_exact
57
+
58
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
59
+ relative_position_if_large = max_exact + (
60
+ torch.log(relative_position.float() / max_exact)
61
+ / math.log(max_distance / max_exact)
62
+ * (num_buckets - max_exact)
63
+ ).to(torch.long)
64
+ relative_position_if_large = torch.min(
65
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
66
+ )
67
+
68
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
69
+ return relative_buckets
70
+
71
+ def compute_bias(self, query_length, key_length, device=None):
72
+ """Compute binned relative position bias"""
73
+ if device is None:
74
+ device = self.relative_attention_bias.weight.device
75
+
76
+ if self.randomized_position:
77
+ context_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
78
+ context_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:query_length])
79
+ context_indices_rand[0] = 0 # root the first element of the sequence
80
+ context_position = context_position[context_indices_rand][:, None]
81
+
82
+ memory_position = torch.arange(self.max_sequence_length, dtype=torch.long, device=device)
83
+ memory_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:key_length])
84
+ memory_indices_rand[0] = 0 # root the first element of the sequence
85
+ memory_position = memory_position[memory_indices_rand][None, :]
86
+ else:
87
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
88
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
89
+
90
+ relative_position = memory_position - context_position # shape (query_length, key_length)
91
+
92
+ relative_position_bucket = self._relative_position_bucket(
93
+ relative_position, # shape (query_length, key_length)
94
+ bidirectional=self.bidirectional,
95
+ num_buckets=self.relative_attention_num_buckets,
96
+ max_distance=self.relative_attention_max_distance,
97
+ )
98
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
99
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
100
+ return values
101
+
102
+ def forward(self, q, k=None, v=None):
103
+
104
+ query_length = q.shape[1]
105
+ key_length = k.shape[1] if k is not None else query_length
106
+ bias = self.compute_bias(query_length, key_length, device=q.device).contiguous().to(q.dtype)
107
+
108
+ return q, k, v, bias
109
+
110
+
111
+ class ALiBiPositionalEncoding(nn.Module):
112
+
113
+ def __init__(self, max_sequence_length, num_heads, mode='symetric', randomized_position=False):
114
+
115
+ super().__init__()
116
+
117
+ self.max_sequence_length = max_sequence_length
118
+ self.num_heads = num_heads
119
+ self.mode = mode
120
+ self.randomized_position = randomized_position
121
+
122
+ self.alibi_bias = self.build_alibi_bias_matrix(num_heads, max_sequence_length, mode)
123
+
124
+ @staticmethod
125
+ def fill_with_neg_inf(t):
126
+ """FP16-compatible function that fills a tensor with -inf."""
127
+ return t.float().fill_(float("-inf")).type_as(t)
128
+
129
+ def get_slopes(self, n):
130
+
131
+ def get_slopes_power_of_2(n):
132
+ start = (2**(-2**-(math.log2(n)-3)))
133
+ ratio = start
134
+ return [start*ratio**i for i in range(n)]
135
+
136
+ if math.log2(n).is_integer():
137
+ return get_slopes_power_of_2(n) #In the paper, we only train models that have 2^a heads for some a. This function has
138
+ else: #some good properties that only occur when the input is a power of 2. To maintain that even
139
+ closest_power_of_2 = 2**math.floor(math.log2(n)) #when the number of heads is not a power of 2, we use this workaround.
140
+ return get_slopes_power_of_2(closest_power_of_2) + self.get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
141
+
142
+ def build_symetric_alibi_bias_matrix(self, num_heads, maxpos):
143
+
144
+ context_position = torch.arange(maxpos)[:, None]
145
+ memory_position = torch.arange(maxpos)[None, :]
146
+
147
+ relative_position = memory_position - context_position
148
+ relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads, -1,-1)
149
+
150
+ slopes = torch.Tensor(self.get_slopes(num_heads)) * -1
151
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
152
+ return alibi.view(1, num_heads, maxpos, maxpos)
153
+
154
+ def build_asymetric_alibi_bias_matrix(self, num_heads, maxpos):
155
+ _future_mask_right = torch.triu(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(num_heads // 2, 1, 1)
156
+ _future_mask_left = torch.tril(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(num_heads // 2, 1, 1)
157
+
158
+ nonsym_mask = torch.cat((_future_mask_right, _future_mask_left), dim = 0).unsqueeze(0)
159
+ slopes = torch.Tensor(self.get_slopes(num_heads // 2)) * -1
160
+
161
+ context_position = torch.arange(maxpos)[:, None]
162
+ memory_position = torch.arange(maxpos)[None, :]
163
+
164
+ relative_position = memory_position - context_position
165
+ relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads // 2, -1,-1)
166
+
167
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
168
+ alibi = alibi.view(1, num_heads // 2, maxpos, maxpos)
169
+ alibi = alibi.repeat(1, 2, 1, 1)
170
+
171
+ return alibi.view(1, num_heads, maxpos, maxpos) + nonsym_mask.view(1, num_heads, maxpos, maxpos)
172
+
173
+
174
+ def build_alibi_bias_matrix(self, num_heads, maxpos, mode='symetric'):
175
+ if mode == 'symetric':
176
+ return self.build_symetric_alibi_bias_matrix(num_heads, maxpos)
177
+ elif mode == 'asymetric':
178
+ return self.build_asymetric_alibi_bias_matrix(num_heads, maxpos)
179
+ else:
180
+ raise ValueError("ALiBi mode " + mode + " is not implemented.")
181
+
182
+ def forward(self, q, k=None, v=None):
183
+
184
+ query_length = q.shape[1]
185
+ key_length = k.shape[1] if k is not None else query_length
186
+ assert (self.alibi_bias.shape[1] < query_length) & (self.alibi_bias.shape[1] < key_length), "Sequence length larger than allowed alibi bound"
187
+
188
+ if self.randomized_position:
189
+ query_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:query_length])
190
+ key_indices_rand, _ = torch.sort(torch.randperm(self.max_sequence_length)[:key_length])
191
+
192
+ # ground sequences
193
+ query_indices_rand[0] = 0
194
+ key_indices_rand[0] = 0
195
+
196
+ bias = self.alibi_bias[:, :, query_indices_rand, key_indices_rand].to(q.device)
197
+
198
+ else:
199
+ bias = self.alibi_bias[:, :, :query_length, :key_length].to(q.device)
200
+
201
+ return q, k, v, bias.to(q.dtype).contiguous()
202
+
203
+ class RotaryPositionalEncoding(nn.Module):
204
+
205
+ def __init__(self, dim,
206
+ max_sequence_length,
207
+ base=10000.0,
208
+ interleaved=False,
209
+ scale_base=None,
210
+ randomized_position=False):
211
+
212
+ super().__init__()
213
+
214
+ self.max_sequence_length = max_sequence_length
215
+ self.randomized_position = randomized_position
216
+
217
+ self.dim = dim
218
+ self.base = base
219
+ self.interleaved = interleaved
220
+ self.scale_base = scale_base
221
+
222
+ inv_freq = self._compute_inv_freq()
223
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
224
+
225
+ scale = (
226
+ (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
227
+ if scale_base is not None
228
+ else None
229
+ )
230
+ self.register_buffer("scale", scale, persistent=False)
231
+
232
+ self._cos_cached = None
233
+ self._sin_cached = None
234
+ self._cos_k_cached = None
235
+ self._sin_k_cached = None
236
+
237
+ def _compute_inv_freq(self, device=None):
238
+ return 1.0 / (
239
+ self.base
240
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
241
+ )
242
+
243
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
244
+ # Reset the tables if the sequence length has changed,
245
+ # if we're on a new device (possibly due to tracing for instance),
246
+ # or if we're switching from inference mode to training
247
+ if (
248
+ self._cos_cached is None
249
+ or self._cos_cached.device != device
250
+ or self._cos_cached.dtype != dtype
251
+ or (self.training and self._cos_cached.is_inference())
252
+ ):
253
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
254
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
255
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
256
+ inv_freq = self._compute_inv_freq(device=device)
257
+
258
+ # Don't do einsum, it converts fp32 to fp16 under AMP
259
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
260
+ t = torch.arange(seqlen, device=device, dtype=dtype)
261
+ freqs = torch.outer(t, inv_freq)
262
+ if self.scale is None:
263
+ self._cos_cached = torch.cos(freqs).to(dtype)
264
+ self._sin_cached = torch.sin(freqs).to(dtype)
265
+ self._cos_k_cached = None
266
+ self._sin_k_cached = None
267
+ else:
268
+ power = (
269
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
270
+ - seqlen // 2
271
+ ) / self.scale_base
272
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
273
+ # We want the multiplication by scale to happen in fp32
274
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
275
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
276
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
277
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
278
+
279
+ def forward(self, q, k=None, v=None):
280
+
281
+ if self._cos_cached is None:
282
+ self._update_cos_sin_cache(self.max_sequence_length, device=q.device, dtype=q.dtype)
283
+
284
+ if k is None and v is None:
285
+ q = apply_rotary_emb_qkv_(
286
+ q,
287
+ self._cos_cached,
288
+ self._sin_cached,
289
+ self._cos_k_cached,
290
+ self._sin_k_cached,
291
+ interleaved=self.interleaved,
292
+ seqlen_offsets=0
293
+ )
294
+ elif v is None and k is not None:
295
+ q = apply_rotary_emb_func(
296
+ q,
297
+ self._cos_cached,
298
+ self._sin_cached,
299
+ interleaved=self.interleaved,
300
+ inplace=True,
301
+ seqlen_offsets=0
302
+ )
303
+
304
+ k = apply_rotary_emb_kv_(
305
+ k,
306
+ self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
307
+ self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
308
+ interleaved=self.interleaved,
309
+ seqlen_offsets=0,
310
+ )
311
+ else:
312
+ q = apply_rotary_emb_func(
313
+ q,
314
+ self._cos_cached,
315
+ self._sin_cached,
316
+ interleaved=self.interleaved,
317
+ inplace=True,
318
+ seqlen_offsets=0
319
+ )
320
+
321
+ k = apply_rotary_emb_func(
322
+ k,
323
+ self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
324
+ self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
325
+ interleaved=self.interleaved,
326
+ seqlen_offsets=0,
327
+ )
328
+
329
+ v = apply_rotary_emb_func(
330
+ v,
331
+ self._cos_cached if self._cos_k_cached is None else self._cos_k_cached,
332
+ self._sin_cached if self._sin_k_cached is None else self._sin_k_cached,
333
+ interleaved=self.interleaved,
334
+ seqlen_offsets=0,
335
+ )
336
+
337
+ return q, k, v, None
rms_norm.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
+ # Copyright 2024 CATIE. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Modifications to the orignal file
17
+ # - add weights gradients
18
+ # - remove the mask if size is a power of 2
19
+ # - support for torch.compile
20
+
21
+ import triton
22
+ import triton.language as tl
23
+ import torch
24
+
25
+
26
+ MAX_FUSED_SIZE = 65536
27
+ next_power_of_2 = triton.next_power_of_2
28
+
29
+ def calculate_settings(n):
30
+ BLOCK_SIZE = next_power_of_2(n)
31
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
32
+ raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
33
+ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
34
+ num_warps = 4
35
+ if BLOCK_SIZE >= 32768: num_warps = 32
36
+ elif BLOCK_SIZE >= 8192: num_warps = 16
37
+ elif BLOCK_SIZE >= 2048: num_warps = 8
38
+ return BLOCK_SIZE, num_warps
39
+
40
+
41
+ @triton.jit
42
+ def _rms_layernorm_forward(
43
+ Y, Y_row_stride,
44
+ X, X_row_stride,
45
+ W, W_row_stride,
46
+ r, r_row_stride,
47
+ n_cols, eps,
48
+ BLOCK_SIZE : tl.constexpr,
49
+ IS_EVEN_X: tl.constexpr
50
+ ):
51
+ """
52
+ Fast RMS Layernorm kernel
53
+ Inspiration from a Triton tutorial:
54
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
55
+ """
56
+ row_idx = tl.program_id(0)
57
+ col_offsets = tl.arange(0, BLOCK_SIZE)
58
+ mask = col_offsets < n_cols
59
+
60
+ Y += row_idx * Y_row_stride
61
+ X += row_idx * X_row_stride
62
+ r += row_idx * r_row_stride
63
+
64
+ if IS_EVEN_X:
65
+ X_row = tl.load(X + col_offsets).to(tl.float32)
66
+ W_row = tl.load(W + col_offsets)
67
+ else:
68
+ X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
69
+ W_row = tl.load(W + col_offsets, mask=mask, other=0)
70
+
71
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
72
+ inv_var = tl.math.rsqrt(row_var + eps)
73
+ tl.store(r, inv_var)
74
+ normed = X_row * inv_var
75
+ normed = normed.to(W_row.dtype) # Exact copy from HF
76
+ output = normed * W_row
77
+
78
+ if IS_EVEN_X:
79
+ tl.store(Y + col_offsets, output)
80
+ else:
81
+ tl.store(Y + col_offsets, output, mask=mask)
82
+
83
+ @triton.jit
84
+ def _rms_layernorm_backward(
85
+ dY, dY_row_stride,
86
+ X, X_row_stride,
87
+ W, W_row_stride,
88
+ r, r_row_stride,
89
+ dW, dW_row_stride,
90
+ dX, dX_row_stride,
91
+ n_cols, eps,
92
+ BLOCK_SIZE : tl.constexpr,
93
+ IS_EVEN_X: tl.constexpr
94
+ ):
95
+ """
96
+ Fast RMS Layernorm kernel for the backward pass
97
+ Inspiration from a Triton tutorial:
98
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
99
+ """
100
+ row_idx = tl.program_id(0)
101
+ col_offsets = tl.arange(0, BLOCK_SIZE)
102
+ mask = col_offsets < n_cols
103
+
104
+ dY += row_idx * dY_row_stride
105
+ X += row_idx * X_row_stride
106
+ r += row_idx * r_row_stride
107
+ dW += row_idx * dW_row_stride
108
+ dX += row_idx * dX_row_stride
109
+
110
+ if IS_EVEN_X:
111
+ dY_row = tl.load(dY + col_offsets).to(tl.float32)
112
+ X_row = tl.load(X + col_offsets).to(tl.float32)
113
+ W_row = tl.load(W + col_offsets).to(tl.float32)
114
+ else:
115
+ dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)
116
+ X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
117
+ W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)
118
+
119
+ # Get saved row variance
120
+ inv_var = tl.load(r).to(tl.float32)
121
+ normed = X_row * inv_var
122
+ dW_row = dY_row * normed
123
+
124
+ dY_W = dY_row * W_row
125
+ rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
126
+ output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
127
+
128
+ if IS_EVEN_X:
129
+ tl.store(dW + col_offsets, dW_row)
130
+ tl.store(dX + col_offsets, output)
131
+ else:
132
+ tl.store(dW + col_offsets, dW_row, mask=mask)
133
+ tl.store(dX + col_offsets, output, mask=mask)
134
+
135
+
136
+ # Wrapper for triton kernel for torch.compile - should be unecessary for PyTorch 2.3 ?
137
+ torch.library.define("flasht5::rmsnorm_triton_fwd", "(Tensor X, Tensor W, float eps, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)")
138
+
139
+ @torch.library.impl("flasht5::rmsnorm_triton_fwd", "default")
140
+ def rmsnorm_triton_fwd(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
141
+ Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device="cuda")
142
+ r = torch.empty(n_rows, dtype=torch.float32, device="cuda")
143
+
144
+ _rms_layernorm_forward[(n_rows,)](
145
+ Y, Y.stride(0),
146
+ X, X.stride(0),
147
+ W, W.stride(0),
148
+ r, r.stride(0),
149
+ n_cols, eps,
150
+ BLOCK_SIZE=BLOCK_SIZE,
151
+ IS_EVEN_X=((n_cols % BLOCK_SIZE) == 0),
152
+ num_warps=num_warps
153
+ )
154
+
155
+ return Y, r
156
+
157
+
158
+ @torch.library.impl_abstract("flasht5::rmsnorm_triton_fwd", rmsnorm_triton_fwd)
159
+ def rmsnorm_triton_fwd_abstract(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
160
+ Y = X.new_empty((n_rows, n_cols))
161
+ r = X.new_empty((n_rows))
162
+ return Y, r
163
+
164
+ torch.library.define("flasht5::rmsnorm_triton_bwd", "(Tensor dY, Tensor r, Tensor X, Tensor W, float eps, int n_cols, int n_rows, int BLOCK_SIZE, int num_warps) -> (Tensor, Tensor)")
165
+
166
+ @torch.library.impl("flasht5::rmsnorm_triton_bwd", "default")
167
+ def rmsnorm_triton_bwd(dY, r, X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
168
+
169
+ dX = torch.empty_like(dY)
170
+ dW = torch.empty_like(dY)
171
+
172
+ _rms_layernorm_backward[(n_rows,)](
173
+ dY, dY.stride(0),
174
+ X, X.stride(0),
175
+ W, 1,
176
+ r, 1,
177
+ dW, dW.stride(0),
178
+ dX, dX.stride(0),
179
+ n_cols, eps,
180
+ BLOCK_SIZE=BLOCK_SIZE,
181
+ IS_EVEN_X=((n_cols % BLOCK_SIZE) == 0),
182
+ num_warps=num_warps,
183
+ )
184
+
185
+ return dX, dW
186
+
187
+
188
+ @torch.library.impl_abstract("flasht5::rmsnorm_triton_bwd", rmsnorm_triton_bwd)
189
+ def rmsnorm_triton_bwd_abstract(dY, r, X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps):
190
+ return torch.empty_like(dY), torch.empty_like(dY)
191
+
192
+
193
+ class Fast_RMS_Layernorm(torch.autograd.Function):
194
+ @staticmethod
195
+ def forward(ctx, X, W, eps):
196
+ shape = X.shape
197
+ dim = shape[-1]
198
+ X = X.view(-1, dim)
199
+ n_rows, n_cols = X.shape
200
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
201
+
202
+ Y, r = torch.ops.flasht5.rmsnorm_triton_fwd(X, W, eps, n_cols, n_rows, BLOCK_SIZE, num_warps)
203
+
204
+ ctx.eps = eps
205
+ ctx.BLOCK_SIZE = BLOCK_SIZE
206
+ ctx.num_warps = num_warps
207
+ ctx.save_for_backward(X, W, r)
208
+ return Y.view(*shape)
209
+
210
+ @staticmethod
211
+ def backward(ctx, dY):
212
+ shape = dY.shape
213
+ dim = shape[-1]
214
+ dY = dY.view(-1, dim)
215
+ X, W, r = ctx.saved_tensors
216
+ n_rows, n_cols = dY.shape
217
+ dX = torch.empty_like(dY)
218
+ dW = torch.empty_like(dY)
219
+
220
+ dW, dX = torch.ops.flasht5.rmsnorm_triton_bwd(dY, r, X, W, ctx.eps, n_cols, n_rows, ctx.BLOCK_SIZE, ctx.num_warps)
221
+
222
+ dX = dX.view(*shape)
223
+ return dX, dW.sum(0), None
224
+
225
+ def fast_rms_layernorm(X, W, eps):
226
+ out = Fast_RMS_Layernorm.apply(X, W, eps)
227
+ return out