zaydzuhri commited on
Commit
66dbc1f
·
verified ·
1 Parent(s): 5379428

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. configs/rectified_transformer_340M.json +19 -0
  2. configs/scaled_vanilla_transformer_120M.json +19 -0
  3. configs/transformer_1B.json +22 -0
  4. configs/transformer_340M.json +18 -0
  5. configs/vanilla_transformer_120M.json +19 -0
  6. fla/__init__.py +114 -0
  7. fla/__pycache__/__init__.cpython-312.pyc +0 -0
  8. fla/__pycache__/utils.cpython-312.pyc +0 -0
  9. fla/modules/feature_map.py +300 -0
  10. fla/modules/fused_bitlinear.py +638 -0
  11. fla/modules/fused_kl_div.py +323 -0
  12. fla/modules/fused_linear_cross_entropy.py +570 -0
  13. fla/modules/layernorm.py +1196 -0
  14. fla/modules/parallel.py +37 -0
  15. fla/modules/rotary.py +512 -0
  16. fla/ops/__init__.py +46 -0
  17. fla/ops/abc/__init__.py +7 -0
  18. fla/ops/attn/__pycache__/__init__.cpython-312.pyc +0 -0
  19. fla/ops/attn/__pycache__/naive_rectified.cpython-312.pyc +0 -0
  20. fla/ops/attn/__pycache__/parallel.cpython-312.pyc +0 -0
  21. fla/ops/attn/naive.py +28 -0
  22. fla/ops/attn/parallel.py +629 -0
  23. fla/ops/attn/parallel_rectified.py +643 -0
  24. fla/ops/attn/parallel_softpick.py +650 -0
  25. fla/ops/based/__pycache__/__init__.cpython-312.pyc +0 -0
  26. fla/ops/based/__pycache__/parallel.cpython-312.pyc +0 -0
  27. fla/ops/based/naive.py +72 -0
  28. fla/ops/based/parallel.py +410 -0
  29. fla/ops/common/__pycache__/__init__.cpython-312.pyc +0 -0
  30. fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc +0 -0
  31. fla/ops/common/__pycache__/chunk_o.cpython-312.pyc +0 -0
  32. fla/ops/common/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  33. fla/ops/common/__pycache__/utils.cpython-312.pyc +0 -0
  34. fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  35. fla/ops/delta_rule/__pycache__/wy_fast.cpython-312.pyc +0 -0
  36. fla/ops/delta_rule/naive.py +120 -0
  37. fla/ops/forgetting_attn/__pycache__/parallel.cpython-312.pyc +0 -0
  38. fla/ops/forgetting_attn/parallel.py +708 -0
  39. fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  40. fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  41. fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc +0 -0
  42. fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +197 -0
  43. fla/ops/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  44. fla/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  45. fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  46. fla/ops/gla/chunk.py +1486 -0
  47. fla/ops/gsa/__pycache__/chunk.cpython-312.pyc +0 -0
  48. fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc +0 -0
  49. fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  50. fla/ops/lightning_attn/chunk.py +74 -0
configs/rectified_transformer_340M.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_rectified_attn"
19
+ }
configs/scaled_vanilla_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "parallel_scaled_attn"
19
+ }
configs/transformer_1B.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "elementwise_affine": true,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "fuse_swiglu": true,
8
+ "hidden_act": "swish",
9
+ "hidden_ratio": 4,
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.006,
12
+ "intermediate_size": null,
13
+ "max_position_embeddings": 8192,
14
+ "model_type": "transformer",
15
+ "norm_eps": 1e-06,
16
+ "num_heads": 32,
17
+ "num_hidden_layers": 24,
18
+ "num_kv_heads": null,
19
+ "pad_token_id": 2,
20
+ "rope_theta": 10000.0,
21
+ "tie_word_embeddings": false
22
+ }
configs/transformer_340M.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": true,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 1024,
9
+ "initializer_range": 0.006,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "transformer",
12
+ "num_heads": 16,
13
+ "num_hidden_layers": 24,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": false,
16
+ "use_cache": true,
17
+ "vocab_size": 32000
18
+ }
configs/vanilla_transformer_120M.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "fuse_cross_entropy": true,
6
+ "fuse_norm": false,
7
+ "hidden_act": "swish",
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "max_position_embeddings": 4096,
11
+ "model_type": "transformer",
12
+ "num_heads": 12,
13
+ "num_hidden_layers": 14,
14
+ "norm_eps": 1e-06,
15
+ "tie_word_embeddings": true,
16
+ "use_cache": true,
17
+ "vocab_size": 32000,
18
+ "attn_impl": "naive_attn"
19
+ }
fla/__init__.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.layers import (
4
+ ABCAttention,
5
+ Attention,
6
+ BasedLinearAttention,
7
+ BitAttention,
8
+ DeltaNet,
9
+ GatedDeltaNet,
10
+ GatedDeltaProduct,
11
+ GatedLinearAttention,
12
+ GatedSlotAttention,
13
+ HGRN2Attention,
14
+ HGRNAttention,
15
+ LightNetAttention,
16
+ LinearAttention,
17
+ MultiScaleRetention,
18
+ NativeSparseAttention,
19
+ ReBasedLinearAttention,
20
+ RWKV6Attention,
21
+ RWKV7Attention,
22
+ )
23
+ from fla.models import (
24
+ ABCForCausalLM,
25
+ ABCModel,
26
+ BitNetForCausalLM,
27
+ BitNetModel,
28
+ DeltaNetForCausalLM,
29
+ DeltaNetModel,
30
+ GatedDeltaNetForCausalLM,
31
+ GatedDeltaNetModel,
32
+ GatedDeltaProductForCausalLM,
33
+ GatedDeltaProductModel,
34
+ GLAForCausalLM,
35
+ GLAModel,
36
+ GSAForCausalLM,
37
+ GSAModel,
38
+ HGRN2ForCausalLM,
39
+ HGRN2Model,
40
+ HGRNForCausalLM,
41
+ LightNetForCausalLM,
42
+ LightNetModel,
43
+ LinearAttentionForCausalLM,
44
+ LinearAttentionModel,
45
+ NSAForCausalLM,
46
+ NSAModel,
47
+ RetNetForCausalLM,
48
+ RetNetModel,
49
+ RWKV6ForCausalLM,
50
+ RWKV6Model,
51
+ RWKV7ForCausalLM,
52
+ RWKV7Model,
53
+ TransformerForCausalLM,
54
+ TransformerModel,
55
+ TransformerWithPruningForCausalLM,
56
+ TransformerWithPruningModel
57
+ )
58
+
59
+ __all__ = [
60
+ 'ABCAttention',
61
+ 'Attention',
62
+ 'BasedLinearAttention',
63
+ 'BitAttention',
64
+ 'DeltaNet',
65
+ 'GatedDeltaNet',
66
+ 'GatedDeltaProduct',
67
+ 'GatedLinearAttention',
68
+ 'GatedSlotAttention',
69
+ 'HGRNAttention',
70
+ 'HGRN2Attention',
71
+ 'LightNetAttention',
72
+ 'LinearAttention',
73
+ 'MultiScaleRetention',
74
+ 'NativeSparseAttention',
75
+ 'ReBasedLinearAttention',
76
+ 'RWKV6Attention',
77
+ 'RWKV7Attention',
78
+ 'ABCForCausalLM',
79
+ 'ABCModel',
80
+ 'BitNetForCausalLM',
81
+ 'BitNetModel',
82
+ 'DeltaNetForCausalLM',
83
+ 'DeltaNetModel',
84
+ 'GatedDeltaNetForCausalLM',
85
+ 'GatedDeltaNetModel',
86
+ 'GatedDeltaProductForCausalLM',
87
+ 'GatedDeltaProductModel',
88
+ 'GLAForCausalLM',
89
+ 'GLAModel',
90
+ 'GSAForCausalLM',
91
+ 'GSAModel',
92
+ 'HGRNForCausalLM',
93
+ 'HGRNModel',
94
+ 'HGRN2ForCausalLM',
95
+ 'HGRN2Model',
96
+ 'LightNetForCausalLM',
97
+ 'LightNetModel',
98
+ 'LinearAttentionForCausalLM',
99
+ 'LinearAttentionModel',
100
+ 'NSAForCausalLM',
101
+ 'NSAModel',
102
+ 'RetNetForCausalLM',
103
+ 'RetNetModel',
104
+ 'RWKV6ForCausalLM',
105
+ 'RWKV6Model',
106
+ 'RWKV7ForCausalLM',
107
+ 'RWKV7Model',
108
+ 'TransformerForCausalLM',
109
+ 'TransformerModel',
110
+ 'TransformerWithPruningForCausalLM',
111
+ 'TransformerWithPruningModel',
112
+ ]
113
+
114
+ __version__ = '0.1.2'
fla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.95 kB). View file
 
fla/__pycache__/utils.cpython-312.pyc ADDED
Binary file (12.3 kB). View file
 
fla/modules/feature_map.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from fla.modules.activations import fast_gelu_impl, sigmoid, sqrelu, swish
13
+ from fla.modules.layernorm import layer_norm
14
+ from fla.utils import checkpoint
15
+
16
+
17
+ @checkpoint
18
+ def flatten_diag_outer_product(x, y):
19
+ z = torch.einsum("...i,...j->...ij", x, y)
20
+ N = z.size(-1)
21
+ indicies = torch.triu_indices(N, N)
22
+ return z[..., indicies[0], indicies[1]]
23
+
24
+
25
+ @checkpoint
26
+ def flatten_diag_outer_product_off1(x, y):
27
+ z = torch.einsum("...i,...j->...ij", x, y)
28
+ N = z.size(-1)
29
+ indicies = torch.triu_indices(N, N, 1)
30
+ indices2 = torch.arange(0, N)
31
+ return z[..., indicies[0], indicies[1]], z[..., indices2, indices2]
32
+
33
+
34
+ def is_power_of_2(n):
35
+ return (n & (n - 1) == 0) and n != 0
36
+
37
+
38
+ class HedgehogFeatureMap(nn.Module):
39
+
40
+ r"""
41
+ Hedgehog feature map as introduced in
42
+ `The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry <https://arxiv.org/abs/2402.04347>`_
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ head_dim: int
48
+ ) -> HedgehogFeatureMap:
49
+ super().__init__()
50
+ # Trainable map
51
+ self.layer = nn.Linear(head_dim, head_dim)
52
+ self.init_weights_()
53
+
54
+ def init_weights_(self):
55
+ """Initialize trainable map as identity"""
56
+ with torch.no_grad():
57
+ identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float)
58
+ self.layer.weight.copy_(identity.to(self.layer.weight))
59
+ nn.init.zeros_(self.layer.bias)
60
+
61
+ def forward(self, x: torch.Tensor):
62
+ x = self.layer(x) # shape b, h, l, d
63
+ return torch.cat([2*x, -2*x], dim=-1).softmax(-1)
64
+
65
+
66
+ class T2RFeatureMap(nn.Module):
67
+
68
+ r"""
69
+ Simple linear mapping feature map as in
70
+ `Finetuning Pretrained Transformers into RNNs <https://arxiv.org/abs/2103.13076>`_
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ head_dim: int,
76
+ dot_dim: int = None,
77
+ bias: Optional[bool] = False
78
+ ) -> T2RFeatureMap:
79
+ super().__init__()
80
+ # Trainable map
81
+ if dot_dim is None:
82
+ dot_dim = head_dim
83
+
84
+ self.head_dim = head_dim
85
+ self.dot_dim = dot_dim
86
+ self.bias = bias
87
+
88
+ self.layer = nn.Linear(head_dim, dot_dim, bias=bias)
89
+
90
+ def __repr__(self) -> str:
91
+ return f"{self.__class__.__name__}(head_dim={self.head_dim}, dot_dim={self.dot_dim}, bias={self.bias})"
92
+
93
+ def forward(self, x: torch.Tensor):
94
+ return self.layer(x).relu()
95
+
96
+
97
+ class DPFPFeatureMap(nn.Module):
98
+
99
+ r"""
100
+ Deterministic Parameter-Free Projection (DPFP) feature map in
101
+ `Linear Transformers Are Secretly Fast Weight Programmers <https://arxiv.org/abs/2102.11174>`_
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ head_dim: int,
107
+ nu: int = 4
108
+ ) -> DPFPFeatureMap:
109
+ super().__init__()
110
+ self.nu = nu
111
+
112
+ def forward(self, x: torch.Tensor):
113
+ x = torch.cat([x.relu(), -x.relu()], dim=-1)
114
+ x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1)
115
+ x_repeat = torch.cat([x] * self.nu, dim=-1)
116
+ return x_repeat * x_rolled
117
+
118
+
119
+ class HadamardFeatureMap(nn.Module):
120
+ def __init__(
121
+ self,
122
+ head_dim: int
123
+ ) -> HadamardFeatureMap:
124
+ super().__init__()
125
+ # Trainable map
126
+ self.layer1 = nn.Linear(head_dim, head_dim)
127
+ self.layer2 = nn.Linear(head_dim, head_dim)
128
+
129
+ def forward(self, x: torch.Tensor):
130
+ return self.layer1(x) * self.layer2(x)
131
+
132
+
133
+ class LearnableOuterProductFeatureMap(nn.Module):
134
+ def __init__(
135
+ self,
136
+ head_dim: int,
137
+ feature_dim: int
138
+ ) -> LearnableOuterProductFeatureMap:
139
+ super().__init__()
140
+ # Trainable map
141
+ self.layer1 = nn.Linear(head_dim, feature_dim, bias=False)
142
+ self.layer2 = nn.Linear(head_dim, feature_dim, bias=False)
143
+ self.normalizer = feature_dim ** -0.5
144
+
145
+ def forward(self, x: torch.Tensor):
146
+ return flatten_diag_outer_product(self.layer1(x), self.layer2(x))
147
+
148
+
149
+ class LearnablePolySketchNonNegativeFeatureMap(nn.Module):
150
+
151
+ def __init__(
152
+ self,
153
+ head_dim: int,
154
+ sketch_size: Optional[int] = None,
155
+ degree: Optional[int] = 2
156
+ ) -> LearnablePolySketchNonNegativeFeatureMap:
157
+ super().__init__()
158
+
159
+ assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2"
160
+
161
+ self.head_dim = head_dim
162
+ self.sketch_size = sketch_size if sketch_size is not None else head_dim
163
+ self.degree = degree
164
+
165
+ self.gamma = nn.Parameter(torch.ones(head_dim))
166
+ self.beta = nn.Parameter(torch.zeros(head_dim))
167
+ # NOTE: the sketch layers defined here are quite different from the original paper
168
+ # currently we simply use linear layers without any non-linear activations
169
+ self.sketches1 = nn.ModuleList([
170
+ nn.Linear(head_dim, sketch_size, bias=False),
171
+ *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
172
+ ])
173
+ self.sketches2 = nn.ModuleList([
174
+ nn.Linear(head_dim, sketch_size, bias=False),
175
+ *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)]
176
+ ])
177
+
178
+ def forward(self, x: torch.Tensor):
179
+ # Section 2.1
180
+ x = layer_norm(x, self.gamma, self.beta)
181
+ # first map the input to sketch size with learnable parameters
182
+ x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5
183
+ for i in range(1, int(math.log2(self.degree)) - 1):
184
+ x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5
185
+ # do sketch mapping for log2(p) - 1 times in total
186
+ # do p=2 mapping to ensure non-negativity
187
+ return flatten_diag_outer_product(x, x)
188
+
189
+
190
+ class TaylorFeatureMap(nn.Module):
191
+ def __init__(
192
+ self,
193
+ head_dim: int
194
+ ) -> TaylorFeatureMap:
195
+ super().__init__()
196
+ self.head_dim = head_dim
197
+ self.r2 = math.sqrt(2)
198
+ self.rd = math.sqrt(self.head_dim)
199
+ self.rrd = math.sqrt(self.rd)
200
+
201
+ def forward(self, x: torch.Tensor):
202
+ x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
203
+ return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1)
204
+
205
+
206
+ class RebasedFeatureMap(nn.Module):
207
+
208
+ def __init__(
209
+ self,
210
+ head_dim: int,
211
+ use_gamma: Optional[bool] = True,
212
+ use_beta: Optional[bool] = True,
213
+ normalize: Optional[bool] = True
214
+ ) -> RebasedFeatureMap:
215
+ super().__init__()
216
+
217
+ self.head_dim = head_dim
218
+ self.use_gamma = use_gamma
219
+ self.use_beta = use_beta
220
+ self.normalize = normalize
221
+
222
+ self.gamma = None
223
+ self.beta = None
224
+ if use_gamma:
225
+ self.gamma = nn.Parameter(torch.ones(head_dim))
226
+ if use_beta:
227
+ self.beta = nn.Parameter(torch.zeros(head_dim))
228
+
229
+ def forward(self, x: torch.Tensor, flatten: Optional[bool] = True):
230
+ if self.use_beta and self.use_gamma and self.normalize:
231
+ x = layer_norm(x, self.gamma, self.beta)
232
+ elif self.normalize:
233
+ x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta)
234
+ elif self.use_gamma and self.use_beta:
235
+ x = torch.addcmul(self.beta, x, self.gamma)
236
+ elif self.use_gamma:
237
+ x = x.mul(self.gamma)
238
+ else:
239
+ raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, "
240
+ f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)")
241
+ if not flatten:
242
+ return x
243
+ x2_1, x2_2 = flatten_diag_outer_product_off1(x, x)
244
+ # rebased use learnable parameters to approximate any quadratic function
245
+ return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1)
246
+
247
+
248
+ class ReLUFeatureMap(nn.Module):
249
+
250
+ def __init__(
251
+ self,
252
+ ) -> ReLUFeatureMap:
253
+ super().__init__()
254
+
255
+ def forward(self, x: torch.Tensor):
256
+ return F.relu(x)
257
+
258
+
259
+ class SquaredReLUFeatureMap(nn.Module):
260
+
261
+ def __init__(
262
+ self,
263
+ ) -> SquaredReLUFeatureMap:
264
+ super().__init__()
265
+
266
+ def forward(self, x: torch.Tensor):
267
+ return sqrelu(x)
268
+
269
+
270
+ class GELUFeatureMap(nn.Module):
271
+
272
+ def __init__(
273
+ self,
274
+ ) -> GELUFeatureMap:
275
+ super().__init__()
276
+
277
+ def forward(self, x: torch.Tensor):
278
+ return fast_gelu_impl(x)
279
+
280
+
281
+ class SwishFeatureMap(nn.Module):
282
+
283
+ def __init__(
284
+ self,
285
+ ) -> SwishFeatureMap:
286
+ super().__init__()
287
+
288
+ def forward(self, x: torch.Tensor):
289
+ return swish(x)
290
+
291
+
292
+ class SigmoidFeatureMap(nn.Module):
293
+
294
+ def __init__(
295
+ self,
296
+ ) -> SigmoidFeatureMap:
297
+ super().__init__()
298
+
299
+ def forward(self, x: torch.Tensor):
300
+ return sigmoid(x)
fla/modules/fused_bitlinear.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # Implementations of BitLinear layer with fused LayerNorm and quantized Linear layer.
5
+ # [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)
6
+ # [Scalable MatMul-free Language Modeling](https://arxiv.org/abs/2406.02528)
7
+
8
+ # Code adapted from https://github.com/ridgerchu/matmulfreellm/
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import triton
18
+ import triton.language as tl
19
+
20
+ from fla.modules.layernorm import RMSNorm
21
+ from fla.utils import get_multiprocessor_count, input_guard, require_version
22
+
23
+
24
+ def activation_quant(x):
25
+ """
26
+ Per-token quantization to 8 bits. No grouping is needed for quantization.
27
+
28
+ Args:
29
+ x: An activation tensor with shape [n, d].
30
+
31
+ Returns:
32
+ A quantized activation tensor with shape [n, d].
33
+ """
34
+ # Compute the scale factor
35
+ scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
36
+ # Quantize and then de-quantize the tensor
37
+ y = (x * scale).round().clamp_(-128, 127) / scale
38
+ return y
39
+
40
+
41
+ def weight_quant(w):
42
+ """
43
+ Per-tensor quantization to 1.58 bits. No grouping is needed for quantization.
44
+
45
+ Args:
46
+ w: A weight tensor with shape [d, k].
47
+
48
+ Returns:
49
+ A quantized weight tensor with shape [d, k].
50
+ """
51
+ # Compute the scale factor
52
+ scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
53
+ # Quantize and then de-quantize the tensor
54
+ u = (w * scale).round().clamp_(-1, 1) / scale
55
+ return u
56
+
57
+
58
+ @triton.autotune(
59
+ configs=[
60
+ triton.Config({}, num_warps=1),
61
+ triton.Config({}, num_warps=2),
62
+ triton.Config({}, num_warps=4),
63
+ triton.Config({}, num_warps=8),
64
+ triton.Config({}, num_warps=16),
65
+ triton.Config({}, num_warps=32),
66
+ ],
67
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
68
+ )
69
+ @triton.jit
70
+ def layer_norm_fwd_kernel_quant(
71
+ X, # pointer to the input
72
+ Y, # pointer to the output
73
+ W, # pointer to the weights
74
+ B, # pointer to the biases
75
+ RESIDUAL, # pointer to the residual
76
+ RESIDUAL_OUT, # pointer to the residual
77
+ Mean, # pointer to the mean
78
+ Rstd, # pointer to the 1/std
79
+ stride_x_row, # how much to increase the pointer when moving by 1 row
80
+ stride_y_row,
81
+ stride_res_row,
82
+ stride_res_out_row,
83
+ N, # number of columns in X
84
+ eps, # epsilon to avoid division by zero
85
+ IS_RMS_NORM: tl.constexpr,
86
+ BLOCK_N: tl.constexpr,
87
+ HAS_RESIDUAL: tl.constexpr,
88
+ STORE_RESIDUAL_OUT: tl.constexpr,
89
+ HAS_WEIGHT: tl.constexpr,
90
+ HAS_BIAS: tl.constexpr
91
+ ):
92
+ # Map the program id to the row of X and Y it should compute.
93
+ row = tl.program_id(0)
94
+ X += row * stride_x_row
95
+ Y += row * stride_y_row
96
+ if HAS_RESIDUAL:
97
+ RESIDUAL += row * stride_res_row
98
+ if STORE_RESIDUAL_OUT:
99
+ RESIDUAL_OUT += row * stride_res_out_row
100
+ # Compute mean and variance
101
+ cols = tl.arange(0, BLOCK_N)
102
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
103
+ if HAS_RESIDUAL:
104
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
105
+ x += residual
106
+ if STORE_RESIDUAL_OUT:
107
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
108
+ if not IS_RMS_NORM:
109
+ mean = tl.sum(x, axis=0) / N
110
+ tl.store(Mean + row, mean)
111
+ xbar = tl.where(cols < N, x - mean, 0.0)
112
+ var = tl.sum(xbar * xbar, axis=0) / N
113
+ else:
114
+ xbar = tl.where(cols < N, x, 0.0)
115
+ var = tl.sum(xbar * xbar, axis=0) / N
116
+ rstd = 1 / tl.sqrt(var + eps)
117
+ tl.store(Rstd + row, rstd)
118
+ # Normalize and apply linear transformation
119
+ mask = cols < N
120
+ if HAS_WEIGHT:
121
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
122
+ if HAS_BIAS:
123
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
124
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
125
+
126
+ y = x_hat * w if HAS_WEIGHT else x_hat
127
+ if HAS_BIAS:
128
+ y = y + b
129
+
130
+ # Aply quantization to the output
131
+ scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)
132
+ # Quantize and then de-quantize the tensor
133
+ y = tl.extra.cuda.libdevice.round(y * scale)
134
+ y = tl.maximum(tl.minimum(y, 127), -128) / scale
135
+
136
+ # Write output
137
+ tl.store(Y + cols, y, mask=mask)
138
+
139
+
140
+ def layer_norm_fwd_quant(
141
+ x: torch.Tensor,
142
+ weight: torch.Tensor,
143
+ bias: torch.Tensor,
144
+ eps: float,
145
+ residual: torch.Tensor = None,
146
+ out_dtype: torch.dtype = None,
147
+ residual_dtype: torch.dtype = None,
148
+ is_rms_norm: bool = False
149
+ ):
150
+ if residual is not None:
151
+ residual_dtype = residual.dtype
152
+ M, N = x.shape
153
+ # allocate output
154
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
155
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
156
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
157
+ else:
158
+ residual_out = None
159
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
160
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
161
+ # Less than 64KB per feature: enqueue fused kernel
162
+ MAX_FUSED_SIZE = 65536 // x.element_size()
163
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
164
+ if N > BLOCK_N:
165
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
166
+ # heuristics for number of warps
167
+ layer_norm_fwd_kernel_quant[(M,)](
168
+ x,
169
+ y,
170
+ weight,
171
+ bias,
172
+ residual,
173
+ residual_out,
174
+ mean,
175
+ rstd,
176
+ x.stride(0),
177
+ y.stride(0),
178
+ residual.stride(0) if residual is not None else 0,
179
+ residual_out.stride(0) if residual_out is not None else 0,
180
+ N,
181
+ eps,
182
+ is_rms_norm,
183
+ BLOCK_N,
184
+ residual is not None,
185
+ residual_out is not None,
186
+ weight is not None,
187
+ bias is not None,
188
+ )
189
+ # residual_out is None if residual is None and residual_dtype == input_dtype
190
+ return y, mean, rstd, residual_out if residual_out is not None else x
191
+
192
+
193
+ @triton.heuristics({
194
+ "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None
195
+ })
196
+ @triton.autotune(
197
+ configs=[
198
+ triton.Config({}, num_warps=1),
199
+ triton.Config({}, num_warps=2),
200
+ triton.Config({}, num_warps=4),
201
+ triton.Config({}, num_warps=8),
202
+ triton.Config({}, num_warps=16),
203
+ triton.Config({}, num_warps=32),
204
+ ],
205
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
206
+ )
207
+ @triton.jit
208
+ def layer_norm_bwd_kernel(
209
+ X, # pointer to the input
210
+ W, # pointer to the weights
211
+ B, # pointer to the biases
212
+ Y, # pointer to the output to be recomputed
213
+ DY, # pointer to the output gradient
214
+ DX, # pointer to the input gradient
215
+ DW, # pointer to the partial sum of weights gradient
216
+ DB, # pointer to the partial sum of biases gradient
217
+ DRESIDUAL,
218
+ DRESIDUAL_IN,
219
+ Mean, # pointer to the mean
220
+ Rstd, # pointer to the 1/std
221
+ stride_x_row, # how much to increase the pointer when moving by 1 row
222
+ stride_y_row,
223
+ stride_dy_row,
224
+ stride_dx_row,
225
+ stride_dres_row,
226
+ stride_dres_in_row,
227
+ M, # number of rows in X
228
+ N, # number of columns in X
229
+ eps, # epsilon to avoid division by zero
230
+ rows_per_program,
231
+ IS_RMS_NORM: tl.constexpr,
232
+ BLOCK_N: tl.constexpr,
233
+ HAS_DRESIDUAL: tl.constexpr,
234
+ STORE_DRESIDUAL: tl.constexpr,
235
+ HAS_WEIGHT: tl.constexpr,
236
+ HAS_BIAS: tl.constexpr,
237
+ RECOMPUTE_OUTPUT: tl.constexpr,
238
+ ):
239
+ # Map the program id to the elements of X, DX, and DY it should compute.
240
+ row_block_id = tl.program_id(0)
241
+ row_start = row_block_id * rows_per_program
242
+ cols = tl.arange(0, BLOCK_N)
243
+ mask = cols < N
244
+ X += row_start * stride_x_row
245
+ if HAS_DRESIDUAL:
246
+ DRESIDUAL += row_start * stride_dres_row
247
+ if STORE_DRESIDUAL:
248
+ DRESIDUAL_IN += row_start * stride_dres_in_row
249
+ DY += row_start * stride_dy_row
250
+ DX += row_start * stride_dx_row
251
+ if RECOMPUTE_OUTPUT:
252
+ Y += row_start * stride_y_row
253
+ if HAS_WEIGHT:
254
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
255
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
256
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
257
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
258
+ if HAS_BIAS:
259
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
260
+ row_end = min((row_block_id + 1) * rows_per_program, M)
261
+ for row in range(row_start, row_end):
262
+ # Load data to SRAM
263
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
264
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
265
+ if not IS_RMS_NORM:
266
+ mean = tl.load(Mean + row)
267
+ rstd = tl.load(Rstd + row)
268
+ # Compute dx
269
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
270
+ xhat = tl.where(mask, xhat, 0.0)
271
+ if RECOMPUTE_OUTPUT:
272
+ y = xhat * w if HAS_WEIGHT else xhat
273
+ if HAS_BIAS:
274
+ y = y + b
275
+
276
+ # Aply quantization to the output
277
+ scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5)
278
+ # Quantize and then de-quantize the tensor
279
+ y = tl.extra.cuda.libdevice.round(y * scale)
280
+ y = tl.maximum(tl.minimum(y, 127), -128) / scale
281
+
282
+ tl.store(Y + cols, y, mask=mask)
283
+ wdy = dy
284
+ if HAS_WEIGHT:
285
+ wdy = dy * w
286
+ dw += dy * xhat
287
+ if HAS_BIAS:
288
+ db += dy
289
+ if not IS_RMS_NORM:
290
+ c1 = tl.sum(xhat * wdy, axis=0) / N
291
+ c2 = tl.sum(wdy, axis=0) / N
292
+ dx = (wdy - (xhat * c1 + c2)) * rstd
293
+ else:
294
+ c1 = tl.sum(xhat * wdy, axis=0) / N
295
+ dx = (wdy - xhat * c1) * rstd
296
+ if HAS_DRESIDUAL:
297
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
298
+ dx += dres
299
+ # Write dx
300
+ if STORE_DRESIDUAL:
301
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
302
+ tl.store(DX + cols, dx, mask=mask)
303
+
304
+ X += stride_x_row
305
+ if HAS_DRESIDUAL:
306
+ DRESIDUAL += stride_dres_row
307
+ if STORE_DRESIDUAL:
308
+ DRESIDUAL_IN += stride_dres_in_row
309
+ if RECOMPUTE_OUTPUT:
310
+ Y += stride_y_row
311
+ DY += stride_dy_row
312
+ DX += stride_dx_row
313
+ if HAS_WEIGHT:
314
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
315
+ if HAS_BIAS:
316
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
317
+
318
+
319
+ def layer_norm_bwd(
320
+ dy: torch.Tensor,
321
+ x: torch.Tensor,
322
+ weight: torch.Tensor,
323
+ bias: torch.Tensor,
324
+ eps: float,
325
+ mean: torch.Tensor,
326
+ rstd: torch.Tensor,
327
+ dresidual: torch.Tensor = None,
328
+ has_residual: bool = False,
329
+ is_rms_norm: bool = False,
330
+ x_dtype: torch.dtype = None,
331
+ recompute_output: bool = False,
332
+ ):
333
+ M, N = x.shape
334
+ # allocate output
335
+ dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
336
+ dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
337
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
338
+
339
+ # Less than 64KB per feature: enqueue fused kernel
340
+ MAX_FUSED_SIZE = 65536 // x.element_size()
341
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
342
+ if N > BLOCK_N:
343
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
344
+ sm_count = get_multiprocessor_count(x.device.index)
345
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) if weight is not None else None
346
+ _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None
347
+ rows_per_program = math.ceil(M / sm_count)
348
+ grid = (sm_count,)
349
+ layer_norm_bwd_kernel[grid](
350
+ x,
351
+ weight,
352
+ bias,
353
+ y,
354
+ dy,
355
+ dx,
356
+ _dw,
357
+ _db,
358
+ dresidual,
359
+ dresidual_in,
360
+ mean,
361
+ rstd,
362
+ x.stride(0),
363
+ 0 if not recompute_output else y.stride(0),
364
+ dy.stride(0),
365
+ dx.stride(0),
366
+ dresidual.stride(0) if dresidual is not None else 0,
367
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
368
+ M,
369
+ N,
370
+ eps,
371
+ rows_per_program,
372
+ is_rms_norm,
373
+ BLOCK_N,
374
+ dresidual is not None,
375
+ dresidual_in is not None,
376
+ weight is not None,
377
+ bias is not None,
378
+ )
379
+ dw = _dw.sum(0).to(weight.dtype) if weight is not None else None
380
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
381
+ # Don't need to compute dresidual_in separately in this case
382
+ if has_residual and dx.dtype == x.dtype:
383
+ dresidual_in = dx
384
+ return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
385
+
386
+
387
+ class LayerNormLinearQuantFn(torch.autograd.Function):
388
+
389
+ @staticmethod
390
+ @input_guard
391
+ def forward(
392
+ ctx,
393
+ x,
394
+ norm_weight,
395
+ norm_bias,
396
+ linear_weight,
397
+ linear_bias,
398
+ residual=None,
399
+ eps=1e-6,
400
+ prenorm=False,
401
+ residual_in_fp32=False,
402
+ is_rms_norm=False,
403
+ ):
404
+ x_shape_og = x.shape
405
+ # reshape input data into 2D tensor
406
+ x = x.reshape(-1, x.shape[-1])
407
+ if residual is not None:
408
+ assert residual.shape == x_shape_og
409
+ residual = residual.reshape(-1, residual.shape[-1])
410
+ residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
411
+ y, mean, rstd, residual_out = layer_norm_fwd_quant(
412
+ x,
413
+ norm_weight,
414
+ norm_bias,
415
+ eps,
416
+ residual,
417
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
418
+ residual_dtype=residual_dtype,
419
+ is_rms_norm=is_rms_norm,
420
+ )
421
+ y = y.reshape(x_shape_og)
422
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
423
+ linear_weight = weight_quant(linear_weight).to(dtype)
424
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
425
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
426
+ # We don't store y, will be recomputed in the backward pass to save memory
427
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
428
+ ctx.x_shape_og = x_shape_og
429
+ ctx.eps = eps
430
+ ctx.is_rms_norm = is_rms_norm
431
+ ctx.has_residual = residual is not None
432
+ ctx.prenorm = prenorm
433
+ ctx.x_dtype = x.dtype
434
+ ctx.linear_bias_is_none = linear_bias is None
435
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
436
+
437
+ @staticmethod
438
+ @input_guard
439
+ def backward(ctx, dout, *args):
440
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
441
+ dout = dout.reshape(-1, dout.shape[-1])
442
+ dy = F.linear(dout, linear_weight.t())
443
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
444
+ assert dy.shape == x.shape
445
+ if ctx.prenorm:
446
+ dresidual = args[0]
447
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
448
+ assert dresidual.shape == x.shape
449
+ else:
450
+ dresidual = None
451
+ dx, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_bwd(
452
+ dy,
453
+ x,
454
+ norm_weight,
455
+ norm_bias,
456
+ ctx.eps,
457
+ mean,
458
+ rstd,
459
+ dresidual,
460
+ ctx.has_residual,
461
+ ctx.is_rms_norm,
462
+ x_dtype=ctx.x_dtype,
463
+ recompute_output=True
464
+ )
465
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
466
+ return (
467
+ dx.reshape(ctx.x_shape_og),
468
+ dnorm_weight,
469
+ dnorm_bias,
470
+ dlinear_weight,
471
+ dlinear_bias,
472
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
473
+ None,
474
+ None,
475
+ None,
476
+ None,
477
+ )
478
+
479
+
480
+ def layer_norm_linear_quant_fn(
481
+ x,
482
+ norm_weight,
483
+ norm_bias,
484
+ linear_weight,
485
+ linear_bias,
486
+ residual=None,
487
+ eps=1e-6,
488
+ prenorm=False,
489
+ residual_in_fp32=False,
490
+ is_rms_norm=False,
491
+ ):
492
+ return LayerNormLinearQuantFn.apply(
493
+ x,
494
+ norm_weight,
495
+ norm_bias,
496
+ linear_weight,
497
+ linear_bias,
498
+ residual,
499
+ eps,
500
+ prenorm,
501
+ residual_in_fp32,
502
+ is_rms_norm,
503
+ )
504
+
505
+
506
+ def rms_norm_linear_quant(
507
+ x: torch.Tensor,
508
+ norm_weight: torch.Tensor,
509
+ norm_bias: torch.Tensor,
510
+ linear_weight: torch.Tensor,
511
+ linear_bias: torch.Tensor,
512
+ residual: torch.Tensor = None,
513
+ eps: float = 1e-5,
514
+ prenorm: bool = False,
515
+ residual_in_fp32: bool = False
516
+ ):
517
+ return layer_norm_linear_quant_fn(
518
+ x=x,
519
+ norm_weight=norm_weight,
520
+ norm_bias=norm_bias,
521
+ linear_weight=linear_weight,
522
+ linear_bias=linear_bias,
523
+ residual=residual,
524
+ eps=eps,
525
+ prenorm=prenorm,
526
+ residual_in_fp32=residual_in_fp32,
527
+ is_rms_norm=True
528
+ )
529
+
530
+
531
+ @require_version("triton>=3.0", "Triton >= 3.0 is required to do online quantization.")
532
+ def bit_linear(x, weight, bias=None, norm_weight=None, norm_bias=None, eps=1e-8):
533
+ """
534
+ A functional version of BitLinear that applies quantization to activations and weights.
535
+
536
+ Args:
537
+ x: Input tensor with shape [n, d].
538
+ weight: Weight tensor with shape [out_features, in_features].
539
+ bias: Bias tensor with shape [out_features] (optional).
540
+ norm_weight: Weight tensor for RMS normalization with shape [in_features].
541
+ norm_bias: Bias tensor for RMS normalization with shape [in_features].
542
+ eps: A small constant for numerical stability in normalization.
543
+
544
+ Returns:
545
+ Output tensor with shape [n, out_features].
546
+ """
547
+ return layer_norm_linear_quant_fn(
548
+ x,
549
+ norm_weight,
550
+ norm_bias,
551
+ weight,
552
+ bias,
553
+ is_rms_norm=True
554
+ )
555
+
556
+
557
+ class BitLinear(nn.Linear):
558
+ """
559
+ A custom linear layer that applies quantization on both activations and weights.
560
+ This is primarily for training; kernel optimization is needed for efficiency in deployment.
561
+ """
562
+
563
+ def __init__(
564
+ self,
565
+ in_features: int,
566
+ out_features: int,
567
+ bias: bool = False,
568
+ norm_eps: float = 1e-8
569
+ ):
570
+ """
571
+ Initializes the BitLinear layer.
572
+
573
+ Args:
574
+ in_features: Size of each input sample.
575
+ out_features: Size of each output sample.
576
+ bias: If set to False, the layer will not learn an additive bias. Default: True.
577
+ """
578
+ # Initialize the superclass nn.Linear with the given parameters
579
+ super(BitLinear, self).__init__(in_features, out_features, bias=bias)
580
+
581
+ self.norm = RMSNorm(in_features, eps=norm_eps)
582
+
583
+ def __repr__(self) -> str:
584
+ return f"{self.__class__.__name__}({super().extra_repr()}, norm_eps={self.norm.eps})"
585
+
586
+ def forward(self, x):
587
+ """
588
+ Overrides the forward pass to include quantization.
589
+
590
+ Args:
591
+ x: An input tensor with shape [n, d].
592
+
593
+ Returns:
594
+ An output tensor with shape [n, d].
595
+ """
596
+ # Weight tensor
597
+ w = self.weight
598
+
599
+ # Apply RMS normalization to the input
600
+ x_norm = self.norm(x)
601
+
602
+ # Apply quantization to both activations and weights
603
+ # Uses Straight-Through Estimator (STE) trick with .detach() for gradient flow
604
+ x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
605
+ w_quant = w + (weight_quant(w) - w).detach()
606
+ # Perform linear operation with quantized values
607
+ y = F.linear(x_quant, w_quant)
608
+
609
+ return y
610
+
611
+
612
+ class FusedBitLinear(BitLinear):
613
+ """
614
+ A custom linear layer that applies quantization on both activations and weights.
615
+ This is primarily for training; kernel optimization is needed for efficiency in deployment.
616
+ """
617
+
618
+ def __init__(self, in_features, out_features, bias=False):
619
+ """
620
+ Initializes the BitLinear layer.
621
+
622
+ Args:
623
+ in_features: Size of each input sample.
624
+ out_features: Size of each output sample.
625
+ bias: If set to False, the layer will not learn an additive bias. Default: True.
626
+ """
627
+ # Initialize the superclass nn.Linear with the given parameters
628
+ super(FusedBitLinear, self).__init__(in_features, out_features, bias=bias)
629
+
630
+ def forward(self, x):
631
+ return layer_norm_linear_quant_fn(
632
+ x,
633
+ self.norm.weight,
634
+ self.norm.bias,
635
+ self.weight,
636
+ self.bias,
637
+ is_rms_norm=True
638
+ )
fla/modules/fused_kl_div.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.ops.utils.op import exp, log
12
+ from fla.utils import input_guard
13
+
14
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
15
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
16
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
17
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
18
+ MAX_FUSED_SIZE = 65536 // 2
19
+
20
+
21
+ @triton.jit
22
+ def kl_div_kernel(
23
+ logits,
24
+ target_logits,
25
+ loss,
26
+ s_logits,
27
+ s_loss,
28
+ reduction: tl.constexpr,
29
+ N: tl.constexpr,
30
+ V: tl.constexpr,
31
+ BV: tl.constexpr
32
+ ):
33
+ # https://github.com/triton-lang/triton/issues/1058
34
+ # If N*V is too large, i_n * stride will overflow out of int32, so we convert to int64
35
+ i_n = tl.program_id(0).to(tl.int64)
36
+
37
+ logits += i_n * s_logits
38
+ target_logits += i_n * s_logits
39
+
40
+ # m is the max value. use the notation from the paper
41
+ sm = float('-inf')
42
+ tm = float('-inf')
43
+ # d is the sum. use the notation from the paper
44
+ sd, td = 0.0, 0.0
45
+
46
+ NV = tl.cdiv(V, BV)
47
+ for iv in range(0, NV):
48
+ o_x = iv * BV + tl.arange(0, BV)
49
+ # for student
50
+ b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
51
+ b_sm = tl.max(b_sl)
52
+ m_new = tl.maximum(sm, b_sm)
53
+ sd = sd * exp(sm - m_new) + tl.sum(exp(b_sl - m_new))
54
+ sm = m_new
55
+ # for teacher
56
+ b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
57
+ b_tm = tl.max(b_tl)
58
+ m_new = tl.maximum(tm, b_tm)
59
+ td = td * exp(tm - m_new) + tl.sum(exp(b_tl - m_new))
60
+ tm = m_new
61
+
62
+ b_loss = 0.
63
+ # KL(y_true || y) = exp(y_true) * (log(y_true) - log(y))
64
+ for iv in range(0, NV):
65
+ o_x = iv * BV + tl.arange(0, BV)
66
+ b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf'))
67
+ b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf'))
68
+ b_sp_log = b_sl - sm - log(sd)
69
+ b_tp_log = b_tl - tm - log(td)
70
+ b_sp = exp(b_sp_log)
71
+ b_tp = exp(b_tp_log)
72
+ b_kl = tl.where(o_x < V, b_tp * (b_tp_log - b_sp_log), 0)
73
+ b_dl = -b_tp + b_sp
74
+ b_loss += tl.sum(b_kl)
75
+ if reduction == 'batchmean':
76
+ b_dl = b_dl / N
77
+ tl.store(logits + o_x, b_dl, mask=o_x < V)
78
+
79
+ # Normalize the loss by the number of elements if reduction is 'batchmean'
80
+ if reduction == 'batchmean':
81
+ b_loss = b_loss / N
82
+
83
+ tl.store(loss + i_n * s_loss, b_loss)
84
+
85
+
86
+ @triton.jit
87
+ def elementwise_mul_kernel(
88
+ x,
89
+ g,
90
+ N: tl.constexpr,
91
+ B: tl.constexpr
92
+ ):
93
+ """
94
+ This function multiplies each element of the tensor pointed by x with the value pointed by g.
95
+ The multiplication is performed in-place on the tensor pointed by x.
96
+
97
+ Parameters:
98
+ x:
99
+ Pointer to the input tensor.
100
+ g:
101
+ Pointer to the gradient output value.
102
+ N (int):
103
+ The number of columns in the input tensor.
104
+ B (int):
105
+ The block size for Triton operations.
106
+ """
107
+
108
+ # Get the program ID and convert it to int64 to avoid overflow
109
+ i_x = tl.program_id(0).to(tl.int64)
110
+ o_x = i_x * B + tl.arange(0, B)
111
+
112
+ # Load the gradient output value
113
+ b_g = tl.load(g)
114
+ b_x = tl.load(x + o_x, mask=o_x < N)
115
+ tl.store(x + o_x, b_x * b_g, mask=o_x < N)
116
+
117
+
118
+ def fused_kl_div_forward(
119
+ x: torch.Tensor,
120
+ target_x: torch.Tensor,
121
+ weight: torch.Tensor,
122
+ target_weight: torch.Tensor,
123
+ reduction: str = 'batchmean'
124
+ ):
125
+ device = x.device
126
+
127
+ # ideally, we would like to achieve the same memory consumption as [N, H],
128
+ # so the expected chunk size should be:
129
+ # NC = ceil(V / H)
130
+ # C = ceil(N / NC)
131
+ # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048
132
+ N, H, V = *x.shape, weight.shape[0]
133
+ BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
134
+ # TODO: in real cases, we may need to limit the number of chunks NC to
135
+ # ensure the precisions of accumulated gradients
136
+ NC = min(8, triton.cdiv(V, H))
137
+ C = triton.next_power_of_2(triton.cdiv(N, NC))
138
+ NC = triton.cdiv(N, C)
139
+
140
+ dx = torch.zeros_like(x, device=device)
141
+ dw = torch.zeros_like(weight, device=device) if weight is not None else None
142
+ # we use fp32 for loss accumulator
143
+ loss = torch.zeros(N, dtype=torch.float32, device=device)
144
+
145
+ for ic in range(NC):
146
+ start, end = ic * C, min((ic + 1) * C, N)
147
+ # [C, N]
148
+ c_sx = x[start:end]
149
+ c_tx = target_x[start:end]
150
+ # when doing matmul, use the original precision
151
+ # [C, V]
152
+ c_sl = F.linear(c_sx, weight)
153
+ c_tl = F.linear(c_tx, target_weight)
154
+
155
+ # unreduced loss
156
+ c_loss = loss[start:end]
157
+
158
+ # Here we calculate the gradient of c_sx in place so we can save memory.
159
+ kl_div_kernel[(c_sx.shape[0],)](
160
+ logits=c_sl,
161
+ target_logits=c_tl,
162
+ loss=c_loss,
163
+ s_logits=c_sl.stride(-2),
164
+ s_loss=c_loss.stride(-1),
165
+ reduction=reduction,
166
+ N=N,
167
+ V=V,
168
+ BV=BV,
169
+ num_warps=32
170
+ )
171
+
172
+ # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V
173
+ # thus dx[start: end] should be of shape: C x H
174
+ # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
175
+ # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
176
+ # Thus, we need an additional scaling factor of (n_non_ignore/total) to scale the gradients.
177
+ # [C, H]
178
+
179
+ dx[start:end] = torch.mm(c_sl, weight)
180
+
181
+ if weight is not None:
182
+ torch.addmm(input=dw, mat1=c_sl.t(), mat2=c_sx, out=dw)
183
+
184
+ loss = loss.sum()
185
+ return loss, dx, dw
186
+
187
+
188
+ def fused_kl_div_backward(
189
+ do: torch.Tensor,
190
+ dx: torch.Tensor,
191
+ dw: torch.Tensor
192
+ ):
193
+ # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
194
+ if torch.ne(do, torch.tensor(1.0, device=do.device)):
195
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
196
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
197
+ N, H = dx.shape
198
+ B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
199
+
200
+ elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
201
+ x=dx,
202
+ g=do,
203
+ N=N*H,
204
+ B=B,
205
+ num_warps=32,
206
+ )
207
+
208
+ # handle dw
209
+ if dw is not None:
210
+ V, H = dw.shape
211
+ elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
212
+ x=dw,
213
+ g=do,
214
+ N=V*H,
215
+ B=B,
216
+ num_warps=32,
217
+ )
218
+
219
+ return dx, dw
220
+
221
+
222
+ class FusedKLDivLossFunction(torch.autograd.Function):
223
+
224
+ @staticmethod
225
+ @input_guard
226
+ def forward(
227
+ ctx,
228
+ x: torch.Tensor,
229
+ target_x: torch.Tensor,
230
+ weight: torch.Tensor,
231
+ target_weight: torch.Tensor,
232
+ reduction: str
233
+ ):
234
+ loss, dx, dw = fused_kl_div_forward(
235
+ x=x,
236
+ target_x=target_x,
237
+ weight=weight,
238
+ target_weight=target_weight,
239
+ reduction=reduction
240
+ )
241
+ ctx.save_for_backward(dx, dw)
242
+ return loss
243
+
244
+ @staticmethod
245
+ @input_guard
246
+ def backward(ctx, do):
247
+ dx, dw = ctx.saved_tensors
248
+ dx, dw = fused_kl_div_backward(do, dx, dw)
249
+ return dx, None, dw, None, None
250
+
251
+
252
+ def fused_kl_div_loss(
253
+ x: torch.Tensor,
254
+ target_x: torch.Tensor,
255
+ weight: torch.Tensor,
256
+ target_weight: torch.Tensor,
257
+ reduction: str = 'batchmean'
258
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
259
+ """
260
+ Args:
261
+ x (torch.Tensor): [batch_size * seq_len, hidden_size]
262
+ target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
263
+ weight (torch.Tensor): [vocab_size, hidden_size]
264
+ where `vocab_size` is the number of classes.
265
+ target_weight (torch.Tensor): [vocab_size, hidden_size]
266
+ where `vocab_size` is the number of classes.
267
+ reduction:
268
+ Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
269
+ Returns:
270
+ loss
271
+ """
272
+ return FusedKLDivLossFunction.apply(
273
+ x,
274
+ target_x,
275
+ weight,
276
+ target_weight,
277
+ reduction
278
+ )
279
+
280
+
281
+ class FusedKLDivLoss(nn.Module):
282
+
283
+ def __init__(
284
+ self,
285
+ reduction: str = 'batchmean'
286
+ ):
287
+ """
288
+ Args:
289
+ reduction:
290
+ Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'.
291
+ """
292
+ super().__init__()
293
+
294
+ assert reduction in ['batchmean'], f"reduction: {reduction} is not supported"
295
+
296
+ self.reduction = reduction
297
+
298
+ def forward(
299
+ self,
300
+ x: torch.Tensor,
301
+ target_x: torch.Tensor,
302
+ weight: torch.Tensor,
303
+ target_weight: torch.Tensor
304
+ ):
305
+ """
306
+ Args:
307
+ x (torch.Tensor): [batch_size * seq_len, hidden_size]
308
+ target_x (torch.Tensor): [batch_size * seq_len, hidden_size]
309
+ weight (torch.Tensor): [vocab_size, hidden_size]
310
+ where `vocab_size` is the number of classes.
311
+ target_weight (torch.Tensor): [vocab_size, hidden_size]
312
+ where `vocab_size` is the number of classes.
313
+ Returns:
314
+ loss
315
+ """
316
+ loss = fused_kl_div_loss(
317
+ x=x,
318
+ target_x=target_x,
319
+ weight=weight,
320
+ target_weight=target_weight,
321
+ reduction=self.reduction
322
+ )
323
+ return loss
fla/modules/fused_linear_cross_entropy.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Code adapted from
4
+ # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py
5
+
6
+ from functools import partial
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
16
+ from torch.distributed.tensor.parallel import ParallelStyle
17
+
18
+ from fla.ops.utils import logsumexp_fwd
19
+ from fla.ops.utils.op import exp
20
+ from fla.utils import input_guard
21
+
22
+ # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
23
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
24
+ # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
25
+ # The optimal maximum block size depends on your hardware, your kernel, and your dtype
26
+ MAX_FUSED_SIZE = 65536 // 2
27
+
28
+
29
+ @triton.jit
30
+ def cross_entropy_kernel(
31
+ logits,
32
+ lse,
33
+ target,
34
+ loss,
35
+ total,
36
+ ignore_index,
37
+ label_smoothing: tl.constexpr,
38
+ logit_scale: tl.constexpr,
39
+ reduction: tl.constexpr,
40
+ V: tl.constexpr,
41
+ BV: tl.constexpr
42
+ ):
43
+ """
44
+ This kernel computes both cross entropy loss and the gradient of the input.
45
+ We only consider hard label + mean reduction for now.
46
+ Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
47
+
48
+ Args:
49
+ logits:
50
+ Pointer to logits tensor.
51
+ lse:
52
+ Pointer to logsumexp tensor.
53
+ target: Pointer to target tensor.
54
+ loss:
55
+ Pointer to tensor to store the loss.
56
+ V (int):
57
+ The number of columns in the input tensor.
58
+ total (int):
59
+ The number of non-ignored classes.
60
+ ignore_index (int):
61
+ The index to ignore in the target.
62
+ label_smoothing (float):
63
+ The amount of smoothing when computing the loss, where 0.0 means no smoothing.
64
+ reduction (str):
65
+ The string for the reduction to apply
66
+ BV (int):
67
+ The block size for vocab.
68
+ """
69
+
70
+ # https://github.com/triton-lang/triton/issues/1058
71
+ # If B*T*V is too large, i_n * stride will overflow out of int32, so we convert to int64
72
+ i_n = tl.program_id(0).to(tl.int64)
73
+ NV = tl.cdiv(V, BV)
74
+
75
+ # 1. Load target first because if the target is ignore_index, we can return right away
76
+ b_y = tl.load(target + i_n)
77
+
78
+ # 2. locate the start index
79
+ logits += i_n * V
80
+
81
+ if b_y == ignore_index:
82
+ # set all x as 0
83
+ for i in range(0, V, BV):
84
+ o_v = i + tl.arange(0, BV)
85
+ tl.store(logits + o_v, 0.0, mask=o_v < V)
86
+ return
87
+
88
+ # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
89
+ # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
90
+
91
+ # 3. [Online softmax] first pass: compute logsumexp
92
+ # we did this in anouter kernel
93
+ b_l = tl.load(logits + b_y) * logit_scale
94
+ b_lse = tl.load(lse + i_n)
95
+
96
+ # 4. Calculate the loss
97
+ # loss = lse - logits_l
98
+ b_loss = b_lse - b_l
99
+
100
+ # Label smoothing is a general case of normal cross entropy
101
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
102
+ b_z = 0.0
103
+ eps = label_smoothing / V
104
+
105
+ # We need tl.debug_barrier() as mentioned in
106
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
107
+ tl.debug_barrier()
108
+
109
+ # 5. [Online Softmax] Second pass: compute gradients
110
+ # For 'mean' reduction, gradients are normalized by number of non-ignored elements
111
+ # dx_y = (softmax(x_y) - 1) / N
112
+ # dx_i = softmax(x_i) / N, i != y
113
+ # For label smoothing:
114
+ # dx_i = (softmax(x_y) - label_smoothing / V) / N, i != y
115
+ # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
116
+ # = dx_i - (1 - label_smoothing) / N
117
+ for iv in range(0, NV):
118
+ o_v = iv * BV + tl.arange(0, BV)
119
+ b_logits = tl.load(logits + o_v, mask=o_v < V, other=float('-inf')) * logit_scale
120
+ if label_smoothing > 0:
121
+ # scale X beforehand to avoid overflow
122
+ b_z += tl.sum(tl.where(o_v < V, -eps * b_logits, 0.0))
123
+ b_p = (exp(b_logits - b_lse) - eps) * logit_scale
124
+ if reduction == "mean":
125
+ b_p = b_p / total
126
+ tl.store(logits + o_v, b_p, mask=o_v < V)
127
+
128
+ tl.debug_barrier()
129
+
130
+ # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
131
+ # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
132
+ # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
133
+ # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
134
+ # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
135
+ # Refer to H(q', p) in section 7 of the paper:
136
+ # https://arxiv.org/pdf/1512.00567
137
+ # pytorch:
138
+ # https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
139
+ # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
140
+ if label_smoothing > 0:
141
+ b_loss = b_loss * (1 - label_smoothing) + (b_z + label_smoothing * b_lse)
142
+
143
+ # 6. Specially handle the i == y case, where dx_y = dx_i - (1 - label_smoothing) / N
144
+ b_l = tl.load(logits + b_y)
145
+
146
+ # Normalize the loss by the number of non-ignored elements if reduction is "mean"
147
+ if reduction == 'mean':
148
+ b_loss = b_loss / total
149
+ b_l += (label_smoothing - 1) / total * logit_scale
150
+ else:
151
+ b_l += (label_smoothing - 1) * logit_scale
152
+
153
+ tl.store(loss + i_n, b_loss)
154
+ tl.store(logits + b_y, b_l)
155
+
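For intuition, the per-row computation above can be mirrored in plain PyTorch (a minimal sketch, not the fused kernel; the helper name `cross_entropy_row_ref` is hypothetical and 'mean' reduction is assumed):

import torch

def cross_entropy_row_ref(logits, y, total, logit_scale=1.0, label_smoothing=0.0):
    # logits: [V] (one row), y: target index, total: number of non-ignored rows
    z = logits.float() * logit_scale
    lse = torch.logsumexp(z, dim=-1)
    loss = lse - z[y]
    eps = label_smoothing / z.numel()
    if label_smoothing > 0:
        loss = loss * (1 - label_smoothing) + (-(eps * z).sum() + label_smoothing * lse)
    # the gradient the kernel writes back into `logits`, already divided by `total`
    dz = (torch.softmax(z, dim=-1) - eps) * logit_scale
    dz[y] += (label_smoothing - 1) * logit_scale
    return loss / total, dz / total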
156
+
157
+ @triton.jit
158
+ def elementwise_mul_kernel(
159
+ x,
160
+ g,
161
+ N: tl.constexpr,
162
+ B: tl.constexpr
163
+ ):
164
+ """
165
+ This function multiplies each element of the tensor pointed by x with the value pointed by g.
166
+ The multiplication is performed in-place on the tensor pointed by x.
167
+
168
+ Parameters:
169
+ x:
170
+ Pointer to the input tensor.
171
+ g:
172
+ Pointer to the gradient output value.
173
+ N (int):
174
+ The number of columns in the input tensor.
175
+ B (int):
176
+ The block size for Triton operations.
177
+ """
178
+
179
+ # Get the program ID and convert it to int64 to avoid overflow
180
+ i_x = tl.program_id(0).to(tl.int64)
181
+ o_x = i_x * B + tl.arange(0, B)
182
+
183
+ # Load the gradient output value
184
+ b_g = tl.load(g)
185
+ b_x = tl.load(x + o_x, mask=o_x < N)
186
+ tl.store(x + o_x, b_x * b_g, mask=o_x < N)
187
+
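Functionally, a launch of this kernel amounts to an in-place scalar multiply; a sketch of the equivalent eager-mode operation:

import torch

x = torch.randn(1024)   # flattened gradient buffer
g = torch.tensor(2.0)   # incoming grad_output (scalar)
x.mul_(g)               # same effect as elementwise_mul_kernel over the buffer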
188
+
189
+ def fused_linear_cross_entropy_forward(
190
+ x: torch.Tensor,
191
+ target: torch.LongTensor,
192
+ weight: torch.Tensor,
193
+ bias: torch.Tensor = None,
194
+ ignore_index: int = -100,
195
+ label_smoothing: float = 0.0,
196
+ logit_scale: float = 1.0,
197
+ num_chunks: int = 8,
198
+ reduction: str = "mean"
199
+ ):
200
+ device = x.device
201
+ # inputs have shape: [N, H]
202
+ # materialized activations will have shape: [N, V]
203
+ # the increase in memory = [N, V]
204
+ # reduction can be achieved by partitioning the number of tokens N into smaller chunks.
205
+
206
+ # ideally, we would like to achieve the same memory consumption as [N, H],
207
+ # so the expected chunk size should be:
208
+ # NC = ceil(V / H)
209
+ # C = ceil(N / NC)
210
+ # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048
211
+ N, H, V = *x.shape, weight.shape[0]
212
+ BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
213
+ # TODO: in real cases, we may need to limit the number of chunks NC to
214
+ # ensure the precisions of accumulated gradients
215
+ NC = min(num_chunks, triton.cdiv(V, H))
216
+ C = triton.next_power_of_2(triton.cdiv(N, NC))
217
+ NC = triton.cdiv(N, C)
218
+
219
+ # [N, H]
220
+ dx = torch.zeros_like(x, device=device)
221
+ # [V, H]
222
+ dw = torch.zeros_like(weight, device=device, dtype=torch.float) if weight is not None else None
223
+ # [V]
224
+ db = torch.zeros_like(bias, device=device, dtype=torch.float) if bias is not None else None
225
+ # [N]
226
+ loss = torch.zeros(N, device=device, dtype=torch.float)
227
+
228
+ total = target.ne(ignore_index).sum().item()
229
+
230
+ for ic in range(NC):
231
+ start, end = ic * C, min((ic + 1) * C, N)
232
+ # [C, H]
233
+ c_x = x[start:end]
234
+ # when doing matmul, use the original precision
235
+ # [C, V]
236
+ c_logits = F.linear(c_x, weight, bias)
237
+ c_target = target[start:end]
238
+ # [C]
239
+ # keep lse in fp32 to maintain precision
240
+ c_lse = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float)
241
+
242
+ # unreduced loss
243
+ c_loss = loss[start:end]
244
+
245
+ # Here we calculate the gradient of c_logits in place so we can save memory.
246
+ cross_entropy_kernel[(c_logits.shape[0],)](
247
+ logits=c_logits,
248
+ lse=c_lse,
249
+ target=c_target,
250
+ loss=c_loss,
251
+ total=total,
252
+ ignore_index=ignore_index,
253
+ label_smoothing=label_smoothing,
254
+ logit_scale=logit_scale,
255
+ reduction=reduction,
256
+ V=V,
257
+ BV=BV,
258
+ num_warps=32
259
+ )
260
+
261
+ # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V
262
+ # thus dx should be of shape: C x H
263
+ dx[start:end] = torch.mm(c_logits, weight)
264
+
265
+ # keep dw in fp32 to maintain precision
266
+ if weight is not None:
267
+ dw += c_logits.t() @ c_x
268
+
269
+ if bias is not None:
270
+ torch.add(input=db, other=c_logits.sum(0), out=db)
271
+
272
+ loss = loss.sum()
273
+ if dw is not None:
274
+ dw = dw.to(weight)
275
+ if db is not None:
276
+ db = db.to(bias)
277
+ return loss, dx, dw, db
278
+
279
+
280
+ def fused_linear_cross_entropy_backward(
281
+ do: torch.Tensor,
282
+ dx: torch.Tensor,
283
+ dw: torch.Tensor,
284
+ db: torch.Tensor
285
+ ):
286
+ # If cross entropy is the last layer, do is 1.0. Skip the mul to save time
287
+ if torch.ne(do, torch.tensor(1.0, device=do.device)):
288
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
289
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
290
+ N, H = dx.shape
291
+ B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H))
292
+
293
+ elementwise_mul_kernel[(triton.cdiv(N * H, B),)](
294
+ x=dx,
295
+ g=do,
296
+ N=N*H,
297
+ B=B,
298
+ num_warps=32,
299
+ )
300
+
301
+ # handle dw
302
+ if dw is not None:
303
+ V, H = dw.shape
304
+ elementwise_mul_kernel[(triton.cdiv(V * H, B),)](
305
+ x=dw,
306
+ g=do,
307
+ N=V*H,
308
+ B=B,
309
+ num_warps=32,
310
+ )
311
+
312
+ if db is not None:
313
+ V = db.shape[0]
314
+ elementwise_mul_kernel[(triton.cdiv(V, B),)](
315
+ x=db,
316
+ g=do,
317
+ N=V,
318
+ B=B,
319
+ num_warps=32,
320
+ )
321
+ return dx, dw, db
322
+
323
+
324
+ class FusedLinearCrossEntropyFunction(torch.autograd.Function):
325
+
326
+ @staticmethod
327
+ @input_guard
328
+ def forward(
329
+ ctx,
330
+ x: torch.Tensor,
331
+ target: torch.LongTensor,
332
+ weight: torch.Tensor,
333
+ bias: torch.Tensor = None,
334
+ ignore_index: int = -100,
335
+ label_smoothing: float = 0.0,
336
+ logit_scale: float = 1.0,
337
+ num_chunks: int = 8,
338
+ reduction: str = "mean"
339
+ ):
340
+ """
341
+ Fusing the last linear layer with cross-entropy loss
342
+ Reference: https://github.com/mgmalek/efficient_cross_entropy
343
+
344
+ Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
345
+ the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
346
+ compute the gradient at the forward pass. By doing so, we don't have to store the x and target
347
+ for the backward pass.
348
+
349
+ x (torch.Tensor): [batch_size * seq_len, hidden_size]
350
+ target (torch.LongTensor): [batch_size * seq_len]
351
+ where each value is in [0, vocab_size).
352
+ weight (torch.Tensor): [vocab_size, hidden_size]
353
+ where `vocab_size` is the number of classes.
354
+ bias (Optional[torch.Tensor]): [vocab_size]
355
+ where `vocab_size` is the number of classes.
356
+ ignore_index:
357
+ the index to ignore in the target.
358
+ label_smoothing:
359
+ the amount of smoothing when computing the loss, where 0.0 means no smoothing.
360
+ logit_scale: float
361
+ A scaling factor applied to the logits. Default: 1.0
362
+ num_chunks: int
363
+ The number of chunks to split the input tensor into for processing.
364
+ This can help optimize memory usage and computation speed.
365
+ Default: 8
366
+ reduction:
367
+ Specifies the reduction to apply to the output: 'mean' | 'sum'.
368
+ 'mean': the weighted mean of the output is taken,
369
+ 'sum': the output will be summed.
370
+ Default: 'mean'.
371
+ """
372
+ loss, dx, dw, db = fused_linear_cross_entropy_forward(
373
+ x,
374
+ target,
375
+ weight,
376
+ bias,
377
+ ignore_index,
378
+ label_smoothing,
379
+ logit_scale,
380
+ num_chunks,
381
+ reduction
382
+ )
383
+ # detach the precomputed gradients and store them for the backward pass
384
+ ctx.save_for_backward(
385
+ dx.detach(),
386
+ dw.detach() if weight is not None else None,
387
+ db.detach() if bias is not None else None,
388
+ )
389
+ return loss
390
+
391
+ @staticmethod
392
+ @input_guard
393
+ def backward(ctx, do):
394
+ dx, dw, db = ctx.saved_tensors
395
+ dx, dw, db = fused_linear_cross_entropy_backward(do, dx, dw, db)
396
+ return dx, None, dw, db, None, None, None, None, None
397
+
398
+
399
+ def fused_linear_cross_entropy_loss(
400
+ x: torch.Tensor,
401
+ target: torch.LongTensor,
402
+ weight: torch.Tensor,
403
+ bias: torch.Tensor = None,
404
+ ignore_index: int = -100,
405
+ label_smoothing: float = 0.0,
406
+ logit_scale: float = 1.0,
407
+ num_chunks: int = 8,
408
+ reduction: str = "mean"
409
+ ) -> torch.Tensor:
410
+ """
411
+ Args:
412
+ x (torch.Tensor): [batch_size * seq_len, hidden_size]
413
+ target (torch.LongTensor): [batch_size * seq_len]
414
+ where each value is in [0, vocab_size).
415
+ weight (torch.Tensor): [vocab_size, hidden_size]
416
+ where `vocab_size` is the number of classes.
417
+ bias (Optional[torch.Tensor]): [vocab_size]
418
+ where `vocab_size` is the number of classes.
419
+ ignore_index: int.
420
+ If target == ignore_index, the loss is set to 0.0.
421
+ label_smoothing: float
422
+ logit_scale: float
423
+ A scaling factor applied to the logits. Default: 1.0
424
+ num_chunks: int
425
+ The number of chunks to split the input tensor into for processing.
426
+ This can help optimize memory usage and computation speed.
427
+ Default: 8
428
+ reduction:
429
+ Specifies the reduction to apply to the output: 'mean' | 'sum'.
430
+ 'mean': the weighted mean of the output is taken,
431
+ 'sum': the output will be summed.
432
+ Default: 'mean'.
433
+ Returns:
434
+ loss (torch.Tensor): the loss reduced according to `reduction` (a scalar).
435
+ """
436
+ return FusedLinearCrossEntropyFunction.apply(
437
+ x,
438
+ target,
439
+ weight,
440
+ bias,
441
+ ignore_index,
442
+ label_smoothing,
443
+ logit_scale,
444
+ num_chunks,
445
+ reduction
446
+ )
447
+
448
+
449
+ class FusedLinearCrossEntropyLoss(nn.Module):
450
+
451
+ def __init__(
452
+ self,
453
+ ignore_index: int = -100,
454
+ label_smoothing: float = 0.0,
455
+ logit_scale: float = 1.0,
456
+ num_chunks: int = 8,
457
+ reduction: str = "mean"
458
+ ):
459
+ """
460
+ Args:
461
+ ignore_index: int.
462
+ If target == ignore_index, the loss is set to 0.0.
463
+ label_smoothing: float
464
+ logit_scale: float
465
+ A scaling factor applied to the logits. Default: 1.0
466
+ num_chunks: int
467
+ The number of chunks to split the input tensor into for processing.
468
+ This can help optimize memory usage and computation speed.
469
+ Default: 8
470
+ reduction:
471
+ Specifies the reduction to apply to the output: 'mean' | 'sum'.
472
+ 'mean': the weighted mean of the output is taken,
473
+ 'sum': the output will be summed.
474
+ Default: 'mean'.
475
+ """
476
+ super().__init__()
477
+
478
+ assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported"
479
+
480
+ self.ignore_index = ignore_index
481
+ self.label_smoothing = label_smoothing
482
+ self.logit_scale = logit_scale
483
+ self.num_chunks = num_chunks
484
+ self.reduction = reduction
485
+
486
+ @torch.compiler.disable
487
+ def forward(
488
+ self,
489
+ x: torch.Tensor,
490
+ target: torch.LongTensor,
491
+ weight: torch.Tensor,
492
+ bias: Optional[torch.Tensor] = None
493
+ ):
494
+ """
495
+ Args:
496
+ x (torch.Tensor): [batch_size, seq_len, hidden_size]
497
+ target (torch.LongTensor): [batch_size, seq_len]
498
+ where each value is in [0, V).
499
+ weight (torch.Tensor): [vocab_size, hidden_size]
500
+ where `vocab_size` is the number of classes.
501
+ bias (Optional[torch.Tensor]): [vocab_size]
502
+ where `vocab_size` is the number of classes.
503
+ Returns:
504
+ loss
505
+ """
506
+ loss = fused_linear_cross_entropy_loss(
507
+ x.view(-1, x.shape[-1]),
508
+ target.view(-1),
509
+ weight=weight,
510
+ bias=bias,
511
+ ignore_index=self.ignore_index,
512
+ label_smoothing=self.label_smoothing,
513
+ logit_scale=self.logit_scale,
514
+ num_chunks=self.num_chunks,
515
+ reduction=self.reduction
516
+ )
517
+ return loss
518
+
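Typical usage is as a drop-in replacement for an `lm_head` + `nn.CrossEntropyLoss` pair (a usage sketch; the shapes, vocab size and device are assumptions):

import torch

criterion = FusedLinearCrossEntropyLoss(ignore_index=-100)
hidden = torch.randn(2, 16, 1024, device='cuda', requires_grad=True)          # [B, T, H]
labels = torch.randint(0, 32000, (2, 16), device='cuda')                      # [B, T]
lm_head_weight = torch.randn(32000, 1024, device='cuda', requires_grad=True)  # [V, H]

loss = criterion(hidden, labels, lm_head_weight)
loss.backward()  # gradients flow into both `hidden` and `lm_head_weight`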
519
+
520
+ class LinearLossParallel(ParallelStyle):
521
+ def __init__(
522
+ self,
523
+ *,
524
+ sequence_dim: int = 1,
525
+ use_local_output: bool = False,
526
+ ):
527
+ super().__init__()
528
+
529
+ self.sequence_sharding = (Shard(sequence_dim),)
530
+ self.use_local_output = use_local_output
531
+
532
+ @staticmethod
533
+ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
534
+ x, target, weight, bias = inputs
535
+
536
+ if not isinstance(x, DTensor):
537
+ # assume the input passed in is already sharded on the sequence dim and create the DTensor
538
+ x = DTensor.from_local(x, device_mesh, sequence_sharding)
539
+ if x.placements != sequence_sharding:
540
+ x = x.redistribute(placements=sequence_sharding, async_op=True)
541
+ if not isinstance(target, DTensor):
542
+ target = DTensor.from_local(target, device_mesh, [Replicate()])
543
+ if target.placements != sequence_sharding:
544
+ target = target.redistribute(placements=sequence_sharding, async_op=True)
545
+
546
+ if not isinstance(weight, DTensor):
547
+ weight = DTensor.from_local(weight, device_mesh, [Replicate()])
548
+ if weight.placements != [Replicate()]:
549
+ # we replicate the weight/bias in FLCE
550
+ weight = weight.redistribute(placements=[Replicate()], async_op=True)
551
+
552
+ if bias is not None and not isinstance(bias, DTensor):
553
+ bias = DTensor.from_local(bias, device_mesh, [Replicate()])
554
+ if bias is not None and bias.placements != [Replicate()]:
555
+ bias = bias.redistribute(placements=[Replicate()], async_op=True)
556
+
557
+ return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias
558
+
559
+ @staticmethod
560
+ def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
561
+ return outputs.to_local() if use_local_output else outputs
562
+
563
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
564
+ return distribute_module(
565
+ module,
566
+ device_mesh,
567
+ partition_fn=None,
568
+ input_fn=partial(self._prepare_input_fn, self.sequence_sharding),
569
+ output_fn=partial(self._prepare_output_fn, self.use_local_output)
570
+ )
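This style is intended to be applied with `parallelize_module` on an existing device mesh; a sketch under the assumption that the default process group has already been initialized (e.g. via torchrun):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
criterion = FusedLinearCrossEntropyLoss()
criterion = parallelize_module(criterion, mesh, LinearLossParallel(sequence_dim=1))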
fla/modules/layernorm.py ADDED
@@ -0,0 +1,1196 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Tri Dao.
4
+ # https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py
5
+ # Implement residual + layer_norm / rms_norm.
6
+
7
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
8
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
9
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
10
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
11
+
12
+ from __future__ import annotations
13
+
14
+ from functools import partial
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import triton
20
+ import triton.language as tl
21
+ from einops import rearrange
22
+ from torch.distributed import DeviceMesh
23
+ from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
24
+ from torch.distributed.tensor.parallel import ParallelStyle
25
+
26
+ from fla.utils import get_multiprocessor_count, input_guard
27
+
28
+
29
+ def layer_norm_ref(
30
+ x: torch.Tensor,
31
+ weight: torch.Tensor,
32
+ bias: torch.Tensor,
33
+ residual: torch.Tensor = None,
34
+ eps: float = 1e-5,
35
+ prenorm: bool = False,
36
+ upcast: bool = False
37
+ ):
38
+ dtype = x.dtype
39
+ if upcast:
40
+ weight = weight.float()
41
+ bias = bias.float() if bias is not None else None
42
+ if upcast:
43
+ x = x.float()
44
+ residual = residual.float() if residual is not None else residual
45
+ if residual is not None:
46
+ x = (x + residual).to(x.dtype)
47
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
48
+ dtype
49
+ )
50
+ return out if not prenorm else (out, x)
51
+
52
+
53
+ def rms_norm_ref(
54
+ x: torch.Tensor,
55
+ weight: torch.Tensor,
56
+ bias: torch.Tensor,
57
+ residual: torch.Tensor = None,
58
+ eps: float = 1e-5,
59
+ prenorm: bool = False,
60
+ upcast: bool = False
61
+ ):
62
+ dtype = x.dtype
63
+ if upcast:
64
+ weight = weight.float()
65
+ bias = bias.float() if bias is not None else None
66
+ if upcast:
67
+ x = x.float()
68
+ residual = residual.float() if residual is not None else residual
69
+ if residual is not None:
70
+ x = (x + residual).to(x.dtype)
71
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
72
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
73
+ out = out.to(dtype)
74
+ return out if not prenorm else (out, x)
75
+
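These reference functions double as a specification for the fused kernels below; for example, the prenorm residual contract (a usage sketch):

import torch

x = torch.randn(4, 512)
residual = torch.randn(4, 512)
weight = torch.ones(512)

out, new_residual = rms_norm_ref(x, weight, bias=None, residual=residual, prenorm=True, upcast=True)
# `out` is RMSNorm(x + residual) * weight; `new_residual` carries (x + residual) to the next block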
76
+
77
+ def group_norm_ref(
78
+ x: torch.Tensor,
79
+ weight: torch.Tensor,
80
+ bias: torch.Tensor,
81
+ num_groups: int,
82
+ residual: torch.Tensor = None,
83
+ eps: float = 1e-5,
84
+ is_rms_norm: bool = False,
85
+ prenorm: bool = False,
86
+ upcast: bool = False
87
+ ):
88
+ dtype = x.dtype
89
+ if upcast:
90
+ weight = weight.float()
91
+ bias = bias.float() if bias is not None else None
92
+ if upcast:
93
+ x = x.float()
94
+ residual = residual.float() if residual is not None else residual
95
+ if residual is not None:
96
+ x = (x + residual).to(x.dtype)
97
+ residual = x
98
+ x, weight = [
99
+ rearrange(data, "... (g d) -> ... g d", g=num_groups) for data in (x, weight)
100
+ ]
101
+ if bias is not None:
102
+ bias = rearrange(bias, '... (g d) -> ... g d', g=num_groups)
103
+ if not is_rms_norm:
104
+ mean = x.mean(dim=-1, keepdim=True)
105
+ x = x - mean
106
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
107
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
108
+ out = rearrange(out, "... g d -> ... (g d)")
109
+ out = out.to(dtype)
110
+ return out if not prenorm else (out, residual)
111
+
112
+
113
+ class GroupNormRef(nn.Module):
114
+
115
+ def __init__(
116
+ self,
117
+ num_groups: int,
118
+ hidden_size: int,
119
+ elementwise_affine: bool = True,
120
+ bias: bool = False,
121
+ eps: float = 1e-5,
122
+ is_rms_norm: bool = False
123
+ ) -> GroupNormRef:
124
+ super().__init__()
125
+
126
+ if hidden_size % num_groups != 0:
127
+ raise ValueError('num_channels must be divisible by num_groups')
128
+
129
+ self.num_groups = num_groups
130
+ self.hidden_size = hidden_size
131
+ self.elementwise_affine = elementwise_affine
132
+ self.eps = eps
133
+ self.is_rms_norm = is_rms_norm
134
+
135
+ self.register_parameter("weight", None)
136
+ self.register_parameter("bias", None)
137
+ if elementwise_affine:
138
+ self.weight = nn.Parameter(torch.empty(hidden_size))
139
+ if bias:
140
+ self.bias = nn.Parameter(torch.empty(hidden_size))
141
+
142
+ self.reset_parameters()
143
+
144
+ def reset_parameters(self):
145
+ if self.elementwise_affine:
146
+ nn.init.ones_(self.weight)
147
+ if self.bias is not None:
148
+ nn.init.zeros_(self.bias)
149
+
150
+ def __repr__(self) -> str:
151
+ s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
152
+ if not self.elementwise_affine:
153
+ s += f", elementwise_affine={self.elementwise_affine}"
154
+ if self.is_rms_norm:
155
+ s += f", is_rms_norm={self.is_rms_norm}"
156
+ s += f", eps={self.eps}"
157
+ s += ")"
158
+ return s
159
+
160
+ def forward(self, x, residual=None, prenorm=False):
161
+ return group_norm_ref(
162
+ x,
163
+ self.weight,
164
+ self.bias,
165
+ num_groups=self.num_groups,
166
+ residual=residual,
167
+ eps=self.eps,
168
+ is_rms_norm=self.is_rms_norm,
169
+ prenorm=prenorm,
170
+ upcast=True
171
+ )
172
+
173
+
174
+ @triton.autotune(
175
+ configs=[
176
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
177
+ for num_warps in [1, 2, 4, 8, 16, 32]
178
+ for num_stages in [2, 3, 4]
179
+ ],
180
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
181
+ )
182
+ @triton.jit
183
+ def layer_norm_fwd_kernel(
184
+ X, # pointer to the input
185
+ Y, # pointer to the output
186
+ W, # pointer to the weights
187
+ B, # pointer to the biases
188
+ RESIDUAL, # pointer to the residual
189
+ RESIDUAL_OUT, # pointer to the residual
190
+ Mean, # pointer to the mean
191
+ Rstd, # pointer to the 1/std
192
+ N, # number of columns in X
193
+ G, # number of groups
194
+ eps, # epsilon to avoid division by zero
195
+ IS_RMS_NORM: tl.constexpr,
196
+ BLOCK_N: tl.constexpr,
197
+ HAS_RESIDUAL: tl.constexpr,
198
+ STORE_RESIDUAL_OUT: tl.constexpr,
199
+ HAS_WEIGHT: tl.constexpr,
200
+ HAS_BIAS: tl.constexpr
201
+ ):
202
+ # Map the program id to the row of X and Y it should compute.
203
+ row = tl.program_id(0)
204
+ group = row % G
205
+ X += row * N
206
+ Y += row * N
207
+ if HAS_RESIDUAL:
208
+ RESIDUAL += row * N
209
+ if STORE_RESIDUAL_OUT:
210
+ RESIDUAL_OUT += row * N
211
+ # Compute mean and variance
212
+ cols = tl.arange(0, BLOCK_N)
213
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
214
+ if HAS_RESIDUAL:
215
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
216
+ x += residual
217
+ if STORE_RESIDUAL_OUT:
218
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
219
+ if not IS_RMS_NORM:
220
+ mean = tl.sum(x, axis=0) / N
221
+ tl.store(Mean + row, mean)
222
+ xbar = tl.where(cols < N, x - mean, 0.0)
223
+ var = tl.sum(xbar * xbar, axis=0) / N
224
+ else:
225
+ xbar = tl.where(cols < N, x, 0.0)
226
+ var = tl.sum(xbar * xbar, axis=0) / N
227
+ rstd = 1 / tl.sqrt(var + eps)
228
+ tl.store(Rstd + row, rstd)
229
+ # Normalize and apply linear transformation
230
+ mask = cols < N
231
+ if HAS_WEIGHT:
232
+ w = tl.load(W + group * N + cols, mask=mask).to(tl.float32)
233
+ if HAS_BIAS:
234
+ b = tl.load(B + group * N + cols, mask=mask).to(tl.float32)
235
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
236
+
237
+ y = tl.fma(x_hat, w, b) if HAS_WEIGHT and HAS_BIAS else \
238
+ x_hat * w if HAS_WEIGHT else \
239
+ x_hat + b if HAS_BIAS else x_hat
240
+ # Write output
241
+ y = tl.cast(y, dtype=Y.dtype.element_ty, fp_downcast_rounding="rtne")
242
+ tl.store(Y + cols, y, mask=mask)
243
+
244
+
245
+ def layer_norm_fwd(
246
+ x: torch.Tensor,
247
+ weight: torch.Tensor,
248
+ bias: torch.Tensor,
249
+ eps: float,
250
+ residual: torch.Tensor = None,
251
+ out_dtype: torch.dtype = None,
252
+ residual_dtype: torch.dtype = None,
253
+ is_rms_norm: bool = False,
254
+ num_groups: int = 1
255
+ ):
256
+ if residual is not None:
257
+ residual_dtype = residual.dtype
258
+ M, N, G = *x.shape, num_groups
259
+ if residual is not None:
260
+ assert residual.shape == (M, N)
261
+ if weight is not None:
262
+ assert weight.shape == (G * N,)
263
+ if bias is not None:
264
+ assert bias.shape == (G * N,)
265
+ # allocate output
266
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
267
+ if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
268
+ residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)
269
+ else:
270
+ residual_out = None
271
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
272
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
273
+ # Less than 64KB per feature: enqueue fused kernel
274
+ MAX_FUSED_SIZE = 65536 // x.element_size()
275
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
276
+ if N > BLOCK_N:
277
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
278
+ # heuristics for number of warps
279
+ layer_norm_fwd_kernel[(M,)](
280
+ x,
281
+ y,
282
+ weight,
283
+ bias,
284
+ residual,
285
+ residual_out,
286
+ mean,
287
+ rstd,
288
+ N,
289
+ G,
290
+ eps,
291
+ is_rms_norm,
292
+ BLOCK_N,
293
+ residual is not None,
294
+ residual_out is not None,
295
+ weight is not None,
296
+ bias is not None,
297
+ )
298
+ # residual_out is None if residual is None and residual_dtype == input_dtype
299
+ return y, mean, rstd, residual_out if residual_out is not None else x
300
+
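A minimal smoke test for the forward launcher against the reference implementation (a sketch; assumes a CUDA device and arbitrary shapes):

import torch

x = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
w = torch.ones(1024, device='cuda', dtype=torch.float16)

y, mean, rstd, res = layer_norm_fwd(x, w, None, eps=1e-5)
ref = layer_norm_ref(x, w, None, eps=1e-5, upcast=True)
assert torch.allclose(y.float(), ref.float(), atol=1e-2, rtol=1e-2)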
301
+
302
+ @triton.heuristics({
303
+ "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None
304
+ })
305
+ @triton.autotune(
306
+ configs=[
307
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
308
+ for num_warps in [1, 2, 4, 8, 16, 32]
309
+ for num_stages in [2, 3, 4]
310
+ ],
311
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
312
+ )
313
+ @triton.jit
314
+ def layer_norm_bwd_kernel(
315
+ X, # pointer to the input
316
+ W, # pointer to the weights
317
+ B, # pointer to the biases
318
+ Y, # pointer to the output to be recomputed
319
+ DY, # pointer to the output gradient
320
+ DX, # pointer to the input gradient
321
+ DW, # pointer to the partial sum of weights gradient
322
+ DB, # pointer to the partial sum of biases gradient
323
+ DRESIDUAL,
324
+ DRESIDUAL_IN,
325
+ Mean, # pointer to the mean
326
+ Rstd, # pointer to the 1/std
327
+ M, # number of rows in X
328
+ N, # number of columns in X
329
+ G, # number of groups
330
+ rows_per_program,
331
+ programs_per_group,
332
+ IS_RMS_NORM: tl.constexpr,
333
+ BLOCK_N: tl.constexpr,
334
+ HAS_DRESIDUAL: tl.constexpr,
335
+ STORE_DRESIDUAL: tl.constexpr,
336
+ HAS_WEIGHT: tl.constexpr,
337
+ HAS_BIAS: tl.constexpr,
338
+ RECOMPUTE_OUTPUT: tl.constexpr,
339
+ ):
340
+ row_block_id = tl.program_id(0)
341
+ group_id, program_id_in_group = row_block_id // programs_per_group, row_block_id % programs_per_group
342
+
343
+ row_start = group_id + program_id_in_group * G * rows_per_program
344
+ row_end = min(row_start + G * rows_per_program, M)
345
+
346
+ cols = tl.arange(0, BLOCK_N)
347
+ mask = cols < N
348
+
349
+ if HAS_WEIGHT:
350
+ w = tl.load(W + group_id * N + cols, mask=mask).to(tl.float32)
351
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
352
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
353
+ b = tl.load(B + group_id * N + cols, mask=mask, other=0.0).to(tl.float32)
354
+ if HAS_BIAS:
355
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
356
+
357
+ for row in range(row_start, row_end, G):
358
+ # Load data to SRAM
359
+ x = tl.load(X + row * N + cols, mask=mask, other=0).to(tl.float32)
360
+ dy = tl.load(DY + row * N + cols, mask=mask, other=0).to(tl.float32)
361
+ if not IS_RMS_NORM:
362
+ mean = tl.load(Mean + row)
363
+ rstd = tl.load(Rstd + row)
364
+ # Compute dx
365
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
366
+ xhat = tl.where(mask, xhat, 0.0)
367
+ if RECOMPUTE_OUTPUT:
368
+ y = xhat * w if HAS_WEIGHT else xhat
369
+ if HAS_BIAS:
370
+ y = y + b
371
+ tl.store(Y + row * N + cols, y, mask=mask)
372
+ wdy = dy
373
+ if HAS_WEIGHT:
374
+ wdy = dy * w
375
+ dw += dy * xhat
376
+ if HAS_BIAS:
377
+ db += dy
378
+ if not IS_RMS_NORM:
379
+ c1 = tl.sum(xhat * wdy, axis=0) / N
380
+ c2 = tl.sum(wdy, axis=0) / N
381
+ dx = (wdy - (xhat * c1 + c2)) * rstd
382
+ else:
383
+ c1 = tl.sum(xhat * wdy, axis=0) / N
384
+ dx = (wdy - xhat * c1) * rstd
385
+ if HAS_DRESIDUAL:
386
+ dres = tl.load(DRESIDUAL + row * N + cols, mask=mask, other=0).to(tl.float32)
387
+ dx += dres
388
+ # Write dx
389
+ dx = tl.cast(dx, dtype=DX.dtype.element_ty, fp_downcast_rounding="rtne")
390
+ if STORE_DRESIDUAL:
391
+ tl.store(DRESIDUAL_IN + row * N + cols, dx, mask=mask)
392
+ tl.store(DX + row * N + cols, dx, mask=mask)
393
+
394
+ if HAS_WEIGHT:
395
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
396
+ if HAS_BIAS:
397
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
398
+
399
+
400
+ def layer_norm_bwd(
401
+ dy: torch.Tensor,
402
+ x: torch.Tensor,
403
+ weight: torch.Tensor,
404
+ bias: torch.Tensor,
405
+ eps: float,
406
+ mean: torch.Tensor,
407
+ rstd: torch.Tensor,
408
+ dresidual: torch.Tensor = None,
409
+ has_residual: bool = False,
410
+ is_rms_norm: bool = False,
411
+ x_dtype: torch.dtype = None,
412
+ recompute_output: bool = False,
413
+ num_groups: int = 1
414
+ ):
415
+ M, N, G = *x.shape, num_groups
416
+ assert dy.shape == (M, N)
417
+ if dresidual is not None:
418
+ assert dresidual.shape == (M, N)
419
+ if weight is not None:
420
+ assert weight.shape == (G * N,)
421
+ if bias is not None:
422
+ assert bias.shape == (G * N,)
423
+ # allocate output
424
+ dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
425
+ dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
426
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
427
+
428
+ # Less than 64KB per feature: enqueue fused kernel
429
+ MAX_FUSED_SIZE = 65536 // x.element_size()
430
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
431
+ if N > BLOCK_N:
432
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
433
+ # each program handles one group only
434
+ S = triton.cdiv(get_multiprocessor_count(x.device.index), G) * G
435
+ dw = torch.empty((S, N), dtype=torch.float32, device=weight.device) if weight is not None else None
436
+ db = torch.empty((S, N), dtype=torch.float32, device=bias.device) if bias is not None else None
437
+ rows_per_program = triton.cdiv(M, S)
438
+ programs_per_group = S // G
439
+ grid = (S,)
440
+ layer_norm_bwd_kernel[grid](
441
+ x,
442
+ weight,
443
+ bias,
444
+ y,
445
+ dy,
446
+ dx,
447
+ dw,
448
+ db,
449
+ dresidual,
450
+ dresidual_in,
451
+ mean,
452
+ rstd,
453
+ M,
454
+ N,
455
+ G,
456
+ rows_per_program,
457
+ programs_per_group,
458
+ is_rms_norm,
459
+ BLOCK_N,
460
+ dresidual is not None,
461
+ dresidual_in is not None,
462
+ weight is not None,
463
+ bias is not None,
464
+ )
465
+ dw = dw.view(G, -1, N).sum(1).to(weight).view_as(weight) if weight is not None else None
466
+ db = db.view(G, -1, N).sum(1).to(bias).view_as(bias) if bias is not None else None
467
+ # Don't need to compute dresidual_in separately in this case
468
+ if has_residual and dx.dtype == x.dtype:
469
+ dresidual_in = dx
470
+ return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
471
+
472
+
473
+ class LayerNormFunction(torch.autograd.Function):
474
+
475
+ @staticmethod
476
+ @input_guard
477
+ def forward(
478
+ ctx,
479
+ x,
480
+ weight,
481
+ bias,
482
+ residual=None,
483
+ eps=1e-5,
484
+ prenorm=False,
485
+ residual_in_fp32=False,
486
+ is_rms_norm=False,
487
+ num_groups=1
488
+ ):
489
+ x_shape_og = x.shape
490
+
491
+ if x.shape[-1] % num_groups != 0:
492
+ raise ValueError('num_channels must be divisible by num_groups')
493
+ # reshape input data into 2D tensor
494
+ x = x.reshape(-1, (x.shape[-1] // num_groups))
495
+ if residual is not None:
496
+ assert residual.shape == x_shape_og
497
+ residual = residual.reshape_as(x)
498
+ residual_dtype = (
499
+ residual.dtype
500
+ if residual is not None
501
+ else (torch.float32 if residual_in_fp32 else None)
502
+ )
503
+ y, mean, rstd, residual_out = layer_norm_fwd(
504
+ x,
505
+ weight,
506
+ bias,
507
+ eps,
508
+ residual,
509
+ residual_dtype=residual_dtype,
510
+ is_rms_norm=is_rms_norm,
511
+ num_groups=num_groups
512
+ )
513
+ ctx.save_for_backward(residual_out, weight, bias, mean, rstd)
514
+ ctx.x_shape_og = x_shape_og
515
+ ctx.eps = eps
516
+ ctx.is_rms_norm = is_rms_norm
517
+ ctx.num_groups = num_groups
518
+ ctx.has_residual = residual is not None
519
+ ctx.prenorm = prenorm
520
+ ctx.x_dtype = x.dtype
521
+ y = y.reshape(x_shape_og)
522
+ return y if not prenorm else (y, residual_out.reshape(x_shape_og))
523
+
524
+ @staticmethod
525
+ @input_guard
526
+ def backward(ctx, dy, *args):
527
+ x, weight, bias, mean, rstd = ctx.saved_tensors
528
+ dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups))
529
+ assert dy.shape == x.shape
530
+ if ctx.prenorm:
531
+ dresidual = args[0]
532
+ dresidual = dresidual.reshape(-1, x.shape[-1])
533
+ assert dresidual.shape == x.shape
534
+ else:
535
+ dresidual = None
536
+ dx, dw, db, dresidual_in = layer_norm_bwd(
537
+ dy,
538
+ x,
539
+ weight,
540
+ bias,
541
+ ctx.eps,
542
+ mean,
543
+ rstd,
544
+ dresidual,
545
+ ctx.has_residual,
546
+ ctx.is_rms_norm,
547
+ x_dtype=ctx.x_dtype,
548
+ num_groups=ctx.num_groups
549
+ )
550
+ return (
551
+ dx.reshape(ctx.x_shape_og),
552
+ dw,
553
+ db,
554
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
555
+ None,
556
+ None,
557
+ None,
558
+ None,
559
+ None
560
+ )
561
+
562
+
563
+ def layer_norm(
564
+ x: torch.Tensor,
565
+ weight: torch.Tensor,
566
+ bias: torch.Tensor,
567
+ residual: torch.Tensor = None,
568
+ eps: float = 1e-5,
569
+ prenorm: bool = False,
570
+ residual_in_fp32: bool = False,
571
+ is_rms_norm: bool = False
572
+ ):
573
+ return LayerNormFunction.apply(
574
+ x,
575
+ weight,
576
+ bias,
577
+ residual,
578
+ eps,
579
+ prenorm,
580
+ residual_in_fp32,
581
+ is_rms_norm
582
+ )
583
+
584
+
585
+ def group_norm(
586
+ x: torch.Tensor,
587
+ weight: torch.Tensor,
588
+ bias: torch.Tensor,
589
+ residual: torch.Tensor = None,
590
+ eps: float = 1e-5,
591
+ prenorm: bool = False,
592
+ residual_in_fp32: bool = False,
593
+ is_rms_norm: bool = False,
594
+ num_groups: int = 1
595
+ ):
596
+ return LayerNormFunction.apply(
597
+ x,
598
+ weight,
599
+ bias,
600
+ residual,
601
+ eps,
602
+ prenorm,
603
+ residual_in_fp32,
604
+ is_rms_norm,
605
+ num_groups
606
+ )
607
+
608
+
609
+ def rms_norm(
610
+ x: torch.Tensor,
611
+ weight: torch.Tensor,
612
+ bias: torch.Tensor,
613
+ residual: torch.Tensor = None,
614
+ eps: float = 1e-5,
615
+ prenorm: bool = False,
616
+ residual_in_fp32: bool = False
617
+ ):
618
+ return LayerNormFunction.apply(
619
+ x,
620
+ weight,
621
+ bias,
622
+ residual,
623
+ eps,
624
+ prenorm,
625
+ residual_in_fp32,
626
+ True
627
+ )
628
+
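The functional wrappers are autograd-complete, so their gradients can be compared against the reference path (a sketch; assumes a CUDA device):

import torch

x = torch.randn(8, 256, device='cuda', requires_grad=True)
w = torch.randn(256, device='cuda', requires_grad=True)
rms_norm(x, w, bias=None).sum().backward()

x_ref = x.detach().clone().requires_grad_()
w_ref = w.detach().clone().requires_grad_()
rms_norm_ref(x_ref, w_ref, bias=None, upcast=True).sum().backward()

assert torch.allclose(x.grad, x_ref.grad, atol=1e-4, rtol=1e-3)
assert torch.allclose(w.grad, w_ref.grad, atol=1e-4, rtol=1e-3)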
629
+
630
+ def layer_norm_linear(
631
+ x: torch.Tensor,
632
+ norm_weight: torch.Tensor,
633
+ norm_bias: torch.Tensor,
634
+ linear_weight: torch.Tensor,
635
+ linear_bias: torch.Tensor,
636
+ residual: torch.Tensor = None,
637
+ eps: float = 1e-5,
638
+ prenorm: bool = False,
639
+ residual_in_fp32: bool = False,
640
+ is_rms_norm: bool = False,
641
+ num_groups: int = 1
642
+ ):
643
+ return LayerNormLinearFunction.apply(
644
+ x,
645
+ norm_weight,
646
+ norm_bias,
647
+ linear_weight,
648
+ linear_bias,
649
+ residual,
650
+ eps,
651
+ prenorm,
652
+ residual_in_fp32,
653
+ is_rms_norm,
654
+ num_groups
655
+ )
656
+
657
+
658
+ def rms_norm_linear(
659
+ x: torch.Tensor,
660
+ norm_weight: torch.Tensor,
661
+ norm_bias: torch.Tensor,
662
+ linear_weight: torch.Tensor,
663
+ linear_bias: torch.Tensor,
664
+ residual: torch.Tensor = None,
665
+ eps: float = 1e-5,
666
+ prenorm: bool = False,
667
+ residual_in_fp32: bool = False
668
+ ):
669
+ return layer_norm_linear(
670
+ x=x,
671
+ norm_weight=norm_weight,
672
+ norm_bias=norm_bias,
673
+ linear_weight=linear_weight,
674
+ linear_bias=linear_bias,
675
+ residual=residual,
676
+ eps=eps,
677
+ prenorm=prenorm,
678
+ residual_in_fp32=residual_in_fp32,
679
+ is_rms_norm=True
680
+ )
681
+
682
+
683
+ def group_norm_linear(
684
+ x: torch.Tensor,
685
+ norm_weight: torch.Tensor,
686
+ norm_bias: torch.Tensor,
687
+ linear_weight: torch.Tensor,
688
+ linear_bias: torch.Tensor,
689
+ residual: torch.Tensor = None,
690
+ eps: float = 1e-5,
691
+ prenorm: bool = False,
692
+ residual_in_fp32: bool = False,
693
+ is_rms_norm: bool = False,
694
+ num_groups: int = 1
695
+ ):
696
+ return layer_norm_linear(
697
+ x=x,
698
+ norm_weight=norm_weight,
699
+ norm_bias=norm_bias,
700
+ linear_weight=linear_weight,
701
+ linear_bias=linear_bias,
702
+ residual=residual,
703
+ eps=eps,
704
+ prenorm=prenorm,
705
+ residual_in_fp32=residual_in_fp32,
706
+ is_rms_norm=is_rms_norm,
707
+ num_groups=num_groups
708
+ )
709
+
710
+
711
+ class LayerNorm(nn.Module):
712
+
713
+ def __init__(
714
+ self,
715
+ hidden_size: int,
716
+ elementwise_affine: bool = True,
717
+ bias: bool = False,
718
+ eps: float = 1e-5
719
+ ) -> LayerNorm:
720
+ super().__init__()
721
+
722
+ self.hidden_size = hidden_size
723
+ self.elementwise_affine = elementwise_affine
724
+ self.eps = eps
725
+
726
+ self.register_parameter("weight", None)
727
+ self.register_parameter("bias", None)
728
+ if elementwise_affine:
729
+ self.weight = nn.Parameter(torch.empty(hidden_size))
730
+ if bias:
731
+ self.bias = nn.Parameter(torch.empty(hidden_size))
732
+
733
+ self.reset_parameters()
734
+
735
+ def reset_parameters(self):
736
+ if self.elementwise_affine:
737
+ nn.init.ones_(self.weight)
738
+ if self.bias is not None:
739
+ nn.init.zeros_(self.bias)
740
+
741
+ def __repr__(self) -> str:
742
+ s = f"{self.__class__.__name__}({self.hidden_size}"
743
+ if not self.elementwise_affine:
744
+ s += f", elementwise_affine={self.elementwise_affine}"
745
+ s += f", eps={self.eps}"
746
+ s += ")"
747
+ return s
748
+
749
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
750
+ return layer_norm(
751
+ x,
752
+ self.weight,
753
+ self.bias,
754
+ residual=residual,
755
+ eps=self.eps,
756
+ prenorm=prenorm,
757
+ residual_in_fp32=residual_in_fp32
758
+ )
759
+
760
+
761
+ class GroupNorm(nn.Module):
762
+
763
+ def __init__(
764
+ self,
765
+ num_groups: int,
766
+ hidden_size: int,
767
+ elementwise_affine: bool = True,
768
+ bias: bool = False,
769
+ eps: float = 1e-5,
770
+ is_rms_norm: bool = False
771
+ ) -> GroupNorm:
772
+ super().__init__()
773
+
774
+ if hidden_size % num_groups != 0:
775
+ raise ValueError('num_channels must be divisible by num_groups')
776
+
777
+ self.num_groups = num_groups
778
+ self.hidden_size = hidden_size
779
+ self.elementwise_affine = elementwise_affine
780
+ self.eps = eps
781
+ self.is_rms_norm = is_rms_norm
782
+
783
+ self.register_parameter("weight", None)
784
+ self.register_parameter("bias", None)
785
+ if elementwise_affine:
786
+ self.weight = nn.Parameter(torch.empty(hidden_size))
787
+ if bias:
788
+ self.bias = nn.Parameter(torch.empty(hidden_size))
789
+
790
+ self.reset_parameters()
791
+
792
+ def reset_parameters(self):
793
+ if self.elementwise_affine:
794
+ nn.init.ones_(self.weight)
795
+ if self.bias is not None:
796
+ nn.init.zeros_(self.bias)
797
+
798
+ def __repr__(self) -> str:
799
+ s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
800
+ if not self.elementwise_affine:
801
+ s += f", elementwise_affine={self.elementwise_affine}"
802
+ if self.is_rms_norm:
803
+ s += f", is_rms_norm={self.is_rms_norm}"
804
+ s += f", eps={self.eps}"
805
+ s += ")"
806
+ return s
807
+
808
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
809
+ return group_norm(
810
+ x,
811
+ self.weight,
812
+ self.bias,
813
+ residual=residual,
814
+ eps=self.eps,
815
+ prenorm=prenorm,
816
+ residual_in_fp32=residual_in_fp32,
817
+ is_rms_norm=self.is_rms_norm,
818
+ num_groups=self.num_groups
819
+ )
820
+
821
+
822
+ class RMSNorm(nn.Module):
823
+
824
+ def __init__(
825
+ self,
826
+ hidden_size: int,
827
+ elementwise_affine: bool = True,
828
+ bias: bool = False,
829
+ eps: float = 1e-5
830
+ ) -> RMSNorm:
831
+ super().__init__()
832
+
833
+ self.hidden_size = hidden_size
834
+ self.elementwise_affine = elementwise_affine
835
+ self.eps = eps
836
+
837
+ self.register_parameter("weight", None)
838
+ self.register_parameter("bias", None)
839
+ if elementwise_affine:
840
+ self.weight = nn.Parameter(torch.empty(hidden_size))
841
+ if bias:
842
+ self.bias = nn.Parameter(torch.empty(hidden_size))
843
+
844
+ self.reset_parameters()
845
+
846
+ def reset_parameters(self):
847
+ if self.elementwise_affine:
848
+ nn.init.ones_(self.weight)
849
+ if self.bias is not None:
850
+ nn.init.zeros_(self.bias)
851
+
852
+ def __repr__(self) -> str:
853
+ s = f"{self.__class__.__name__}({self.hidden_size}"
854
+ if not self.elementwise_affine:
855
+ s += f", elementwise_affine={self.elementwise_affine}"
856
+ s += f", eps={self.eps}"
857
+ s += ")"
858
+ return s
859
+
860
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
861
+ return rms_norm(
862
+ x,
863
+ self.weight,
864
+ self.bias,
865
+ residual=residual,
866
+ eps=self.eps,
867
+ prenorm=prenorm,
868
+ residual_in_fp32=residual_in_fp32,
869
+ )
870
+
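Module-level usage mirrors `nn.LayerNorm`, with optional fused residual handling for pre-norm blocks (a usage sketch; assumes a CUDA device):

import torch

norm = RMSNorm(hidden_size=512).cuda()
x = torch.randn(2, 128, 512, device='cuda')

y = norm(x)                                                             # plain RMSNorm
y, new_residual = norm(x, residual=torch.zeros_like(x), prenorm=True)   # fused residual-add + norm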
871
+
872
+ class LayerNormLinearFunction(torch.autograd.Function):
873
+
874
+ @staticmethod
875
+ @input_guard
876
+ def forward(
877
+ ctx,
878
+ x,
879
+ norm_weight,
880
+ norm_bias,
881
+ linear_weight,
882
+ linear_bias,
883
+ residual=None,
884
+ eps=1e-5,
885
+ prenorm=False,
886
+ residual_in_fp32=False,
887
+ is_rms_norm=False,
888
+ num_groups=1
889
+ ):
890
+ x_shape_og = x.shape
891
+
892
+ if x.shape[-1] % num_groups != 0:
893
+ raise ValueError('num_channels must be divisible by num_groups')
894
+ # reshape input data into 2D tensor
895
+ x = x.reshape(-1, (x.shape[-1] // num_groups))
896
+ if residual is not None:
897
+ assert residual.shape == x_shape_og
898
+ residual = residual.reshape_as(x)
899
+ residual_dtype = (
900
+ residual.dtype
901
+ if residual is not None
902
+ else (torch.float32 if residual_in_fp32 else None)
903
+ )
904
+ y, mean, rstd, residual_out = layer_norm_fwd(
905
+ x,
906
+ norm_weight,
907
+ norm_bias,
908
+ eps,
909
+ residual,
910
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
911
+ residual_dtype=residual_dtype,
912
+ is_rms_norm=is_rms_norm,
913
+ num_groups=num_groups
914
+ )
915
+ y = y.reshape(x_shape_og)
916
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
917
+ linear_weight = linear_weight.to(dtype)
918
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
919
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
920
+ # We don't store y, will be recomputed in the backward pass to save memory
921
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
922
+ ctx.x_shape_og = x_shape_og
923
+ ctx.eps = eps
924
+ ctx.is_rms_norm = is_rms_norm
925
+ ctx.num_groups = num_groups
926
+ ctx.has_residual = residual is not None
927
+ ctx.prenorm = prenorm
928
+ ctx.x_dtype = x.dtype
929
+ ctx.linear_bias_is_none = linear_bias is None
930
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
931
+
932
+ @staticmethod
933
+ @input_guard
934
+ def backward(ctx, dout, *args):
935
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
936
+ dout = dout.reshape(-1, dout.shape[-1])
937
+ dy = F.linear(dout, linear_weight.t())
938
+ dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups))
939
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
940
+ assert dy.shape == x.shape
941
+ if ctx.prenorm:
942
+ dresidual = args[0]
943
+ dresidual = dresidual.reshape(-1, x.shape[-1])
944
+ assert dresidual.shape == x.shape
945
+ else:
946
+ dresidual = None
947
+ dx, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_bwd(
948
+ dy,
949
+ x,
950
+ norm_weight,
951
+ norm_bias,
952
+ ctx.eps,
953
+ mean,
954
+ rstd,
955
+ dresidual,
956
+ ctx.has_residual,
957
+ ctx.is_rms_norm,
958
+ x_dtype=ctx.x_dtype,
959
+ recompute_output=True,
960
+ num_groups=ctx.num_groups
961
+ )
962
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y.view(-1, linear_weight.shape[-1]))
963
+ return (
964
+ dx.reshape(ctx.x_shape_og),
965
+ dnorm_weight,
966
+ dnorm_bias,
967
+ dlinear_weight,
968
+ dlinear_bias,
969
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
970
+ None,
971
+ None,
972
+ None,
973
+ None,
974
+ None
975
+ )
976
+
977
+
978
+ class LayerNormLinear(nn.Module):
979
+
980
+ def __init__(
981
+ self,
982
+ hidden_size,
983
+ elementwise_affine: bool = True,
984
+ bias: bool = False,
985
+ eps: float = 1e-5
986
+ ) -> LayerNormLinear:
987
+ super().__init__()
988
+
989
+ self.hidden_size = hidden_size
990
+ self.elementwise_affine = elementwise_affine
991
+ self.eps = eps
992
+
993
+ self.register_parameter("weight", None)
994
+ self.register_parameter("bias", None)
995
+ if elementwise_affine:
996
+ self.weight = nn.Parameter(torch.empty(hidden_size))
997
+ if bias:
998
+ self.bias = nn.Parameter(torch.empty(hidden_size))
999
+
1000
+ self.reset_parameters()
1001
+
1002
+ def reset_parameters(self):
1003
+ if self.elementwise_affine:
1004
+ nn.init.ones_(self.weight)
1005
+ if self.bias is not None:
1006
+ nn.init.zeros_(self.bias)
1007
+
1008
+ def __repr__(self) -> str:
1009
+ s = f"{self.__class__.__name__}({self.hidden_size}"
1010
+ if not self.elementwise_affine:
1011
+ s += f", elementwise_affine={self.elementwise_affine}"
1012
+ s += f", eps={self.eps}"
1013
+ s += ")"
1014
+ return s
1015
+
1016
+ def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
1017
+ return layer_norm_linear(
1018
+ x=x,
1019
+ norm_weight=self.weight,
1020
+ norm_bias=self.bias,
1021
+ linear_weight=weight,
1022
+ linear_bias=bias,
1023
+ residual=residual,
1024
+ eps=self.eps,
1025
+ prenorm=prenorm,
1026
+ residual_in_fp32=residual_in_fp32,
1027
+ is_rms_norm=False
1028
+ )
1029
+
1030
+
1031
+ class GroupNormLinear(nn.Module):
1032
+
1033
+ def __init__(
1034
+ self,
1035
+ num_groups: int,
1036
+ hidden_size: int,
1037
+ elementwise_affine: bool = True,
1038
+ bias: bool = False,
1039
+ eps: float = 1e-5,
1040
+ is_rms_norm: bool = False
1041
+ ) -> GroupNormLinear:
1042
+ super().__init__()
1043
+
1044
+ if hidden_size % num_groups != 0:
1045
+ raise ValueError('num_channels must be divisible by num_groups')
1046
+
1047
+ self.num_groups = num_groups
1048
+ self.hidden_size = hidden_size
1049
+ self.elementwise_affine = elementwise_affine
1050
+ self.eps = eps
1051
+ self.is_rms_norm = is_rms_norm
1052
+
1053
+ self.register_parameter("weight", None)
1054
+ self.register_parameter("bias", None)
1055
+ if elementwise_affine:
1056
+ self.weight = nn.Parameter(torch.empty(hidden_size))
1057
+ if bias:
1058
+ self.bias = nn.Parameter(torch.empty(hidden_size))
1059
+
1060
+ self.reset_parameters()
1061
+
1062
+ def reset_parameters(self):
1063
+ if self.elementwise_affine:
1064
+ nn.init.ones_(self.weight)
1065
+ if self.bias is not None:
1066
+ nn.init.zeros_(self.bias)
1067
+
1068
+ def __repr__(self) -> str:
1069
+ s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
1070
+ if not self.elementwise_affine:
1071
+ s += f", elementwise_affine={self.elementwise_affine}"
1072
+ if self.is_rms_norm:
1073
+ s += f", is_rms_norm={self.is_rms_norm}"
1074
+ s += f", eps={self.eps}"
1075
+ s += ")"
1076
+ return s
1077
+
1078
+ def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
1079
+ return layer_norm_linear(
1080
+ x=x,
1081
+ norm_weight=self.weight,
1082
+ norm_bias=self.bias,
1083
+ linear_weight=weight,
1084
+ linear_bias=bias,
1085
+ residual=residual,
1086
+ eps=self.eps,
1087
+ prenorm=prenorm,
1088
+ residual_in_fp32=residual_in_fp32,
1089
+ is_rms_norm=self.is_rms_norm,
1090
+ num_groups=self.num_groups
1091
+ )
1092
+
1093
+
1094
+ class RMSNormLinear(nn.Module):
1095
+
1096
+ def __init__(
1097
+ self,
1098
+ hidden_size,
1099
+ elementwise_affine: bool = True,
1100
+ bias: bool = False,
1101
+ eps: float = 1e-5
1102
+ ) -> RMSNormLinear:
1103
+ super().__init__()
1104
+
1105
+ self.hidden_size = hidden_size
1106
+ self.elementwise_affine = elementwise_affine
1107
+ self.eps = eps
1108
+
1109
+ self.register_parameter("weight", None)
1110
+ self.register_parameter("bias", None)
1111
+ if elementwise_affine:
1112
+ self.weight = nn.Parameter(torch.empty(hidden_size))
1113
+ if bias:
1114
+ self.bias = nn.Parameter(torch.empty(hidden_size))
1115
+
1116
+ self.reset_parameters()
1117
+
1118
+ def reset_parameters(self):
1119
+ if self.elementwise_affine:
1120
+ nn.init.ones_(self.weight)
1121
+ if self.bias is not None:
1122
+ nn.init.zeros_(self.bias)
1123
+
1124
+ def __repr__(self) -> str:
1125
+ s = f"{self.__class__.__name__}({self.hidden_size}"
1126
+ if not self.elementwise_affine:
1127
+ s += f", elementwise_affine={self.elementwise_affine}"
1128
+ s += f", eps={self.eps}"
1129
+ s += ")"
1130
+ return s
1131
+
1132
+ def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
1133
+ return layer_norm_linear(
1134
+ x=x,
1135
+ norm_weight=self.weight,
1136
+ norm_bias=self.bias,
1137
+ linear_weight=weight,
1138
+ linear_bias=bias,
1139
+ residual=residual,
1140
+ eps=self.eps,
1141
+ prenorm=prenorm,
1142
+ residual_in_fp32=residual_in_fp32,
1143
+ is_rms_norm=True
1144
+ )
1145
+
1146
+
1147
+ class NormParallel(ParallelStyle):
1148
+
1149
+ def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False):
1150
+ super().__init__()
1151
+ self.sequence_sharding = (Shard(sequence_dim),)
1152
+ self.use_local_output = use_local_output
1153
+
1154
+ def _replicate_module_fn(
1155
+ self, name: str, module: nn.Module, device_mesh: DeviceMesh
1156
+ ):
1157
+ for p_name, param in module.named_parameters():
1158
+ # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow
1159
+ # us to simply just use from_local
1160
+ replicated_param = torch.nn.Parameter(
1161
+ DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
1162
+ )
1163
+ module.register_parameter(p_name, replicated_param)
1164
+
1165
+ @staticmethod
1166
+ def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
1167
+ input_tensor = inputs[0]
1168
+ if isinstance(input_tensor, DTensor):
1169
+ # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
1170
+ if input_tensor.placements != sequence_sharding:
1171
+ input_tensor = input_tensor.redistribute(
1172
+ placements=sequence_sharding, async_op=True
1173
+ )
1174
+ return input_tensor
1175
+ elif isinstance(input_tensor, torch.Tensor):
1176
+ # assume the input passed in already sharded on the sequence dim and create the DTensor
1177
+ return DTensor.from_local(
1178
+ input_tensor, device_mesh, sequence_sharding, run_check=False
1179
+ )
1180
+ else:
1181
+ raise ValueError(
1182
+ f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
1183
+ )
1184
+
1185
+ @staticmethod
1186
+ def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
1187
+ return outputs.to_local() if use_local_output else outputs
1188
+
1189
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
1190
+ return distribute_module(
1191
+ module,
1192
+ device_mesh,
1193
+ self._replicate_module_fn,
1194
+ partial(self._prepare_input_fn, self.sequence_sharding),
1195
+ partial(self._prepare_output_fn, self.use_local_output),
1196
+ )
fla/modules/parallel.py ADDED
@@ -0,0 +1,37 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch.nn as nn
7
+ from torch.distributed import DeviceMesh
8
+ from torch.distributed.tensor import DTensor, distribute_module
9
+ from torch.distributed.tensor.parallel import ParallelStyle
10
+ from torch.distributed.tensor.placement_types import Placement
11
+
12
+
13
+ class PrepareModuleWeight(ParallelStyle):
14
+ def __init__(self, *, layouts: Optional[Placement] = None):
15
+ super().__init__()
16
+ self.layouts = layouts
17
+
18
+ def _replicate_module_fn(
19
+ self,
20
+ name: str,
21
+ module: nn.Module,
22
+ device_mesh: DeviceMesh
23
+ ):
24
+ for p_name, param in module.named_parameters():
25
+ replicated_param = nn.Parameter(
26
+ DTensor.from_local(param, device_mesh, [self.layouts], run_check=False)
27
+ )
28
+ module.register_parameter(p_name, replicated_param)
29
+
30
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
31
+ return distribute_module(
32
+ module,
33
+ device_mesh,
34
+ partition_fn=self._replicate_module_fn,
35
+ input_fn=None,
36
+ output_fn=None
37
+ )
fla/modules/rotary.py ADDED
@@ -0,0 +1,512 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Tri Dao.
4
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
5
+
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import triton
11
+ import triton.language as tl
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.utils import get_multiprocessor_count, input_guard
15
+
16
+
17
+ def rotate_half(x, interleaved=False):
18
+ if not interleaved:
19
+ x1, x2 = x.chunk(2, dim=-1)
20
+ return torch.cat((-x2, x1), dim=-1)
21
+ else:
22
+ x1, x2 = x[..., ::2], x[..., 1::2]
23
+ return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
24
+
25
+
26
+ def rotary_embedding_ref(x, cos, sin, interleaved=False):
27
+ ro_dim = cos.shape[-1] * 2
28
+ assert ro_dim <= x.shape[-1]
29
+ cos = repeat(cos, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)')
30
+ sin = repeat(sin, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)')
31
+ return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], -1)
32
+
33
+
34
+ @triton.autotune(
35
+ configs=[
36
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
37
+ for num_warps in [2, 4, 8, 16, 32]
38
+ for num_stages in [2, 3, 4]
39
+ ],
40
+ key=['B', 'H', 'D', 'INTERLEAVED'],
41
+ )
42
+ @triton.jit
43
+ def rotary_embedding_kernel(
44
+ x,
45
+ cos,
46
+ sin,
47
+ y,
48
+ cu_seqlens,
49
+ seq_offsets, # this could be int or a pointer
50
+ # Matrix dimensions
51
+ B: tl.constexpr,
52
+ T: tl.constexpr,
53
+ H: tl.constexpr,
54
+ D: tl.constexpr,
55
+ R: tl.constexpr,
56
+ TR: tl.constexpr,
57
+ BT: tl.constexpr,
58
+ BD: tl.constexpr,
59
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
60
+ IS_VARLEN: tl.constexpr,
61
+ INTERLEAVED: tl.constexpr,
62
+ CONJUGATE: tl.constexpr
63
+ ):
64
+ i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)
65
+
66
+ if not IS_VARLEN:
67
+ x = x + i_b * T*H*D + i_h * D
68
+ y = y + i_b * T*H*D + i_h * D
69
+ else:
70
+ bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)
71
+ T = eos - bos
72
+ x = x + bos * H*D + i_h * D
73
+ y = y + bos * H*D + i_h * D
74
+
75
+ if i_t * BT >= T:
76
+ return
77
+
78
+ o_t = i_t * BT + tl.arange(0, BT)
79
+ if not IS_SEQLEN_OFFSETS_TENSOR:
80
+ o_cs = o_t + seq_offsets
81
+ else:
82
+ o_cs = o_t + tl.load(seq_offsets + i_b)
83
+
84
+ if not INTERLEAVED:
85
+ # Load the 1st and 2nd halves of x, do calculation, then store to 1st and 2nd halves of out
86
+ o_r = tl.arange(0, BD // 2)
87
+ p_x = x + o_t[:, None] * H*D + o_r[None, :]
88
+ p_cos = cos + (o_cs[:, None] * R + o_r[None, :])
89
+ p_sin = sin + (o_cs[:, None] * R + o_r[None, :])
90
+ mask = (o_t[:, None] >= 0) & (o_t[:, None] < T) & (o_r[None, :] < R)
91
+
92
+ b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32)
93
+ b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32)
94
+ b_x0 = tl.load(p_x, mask=mask, other=0.0).to(tl.float32)
95
+ b_x1 = tl.load(p_x + R, mask=mask, other=0.0).to(tl.float32)
96
+ if CONJUGATE:
97
+ b_sin = -b_sin
98
+ b_o0 = b_x0 * b_cos - b_x1 * b_sin
99
+ b_o1 = b_x0 * b_sin + b_x1 * b_cos
100
+ # write back result
101
+ p_y = y + (o_t[:, None] * H*D + o_r[None, :])
102
+ tl.store(p_y, b_o0, mask=mask)
103
+ tl.store(p_y + R, b_o1, mask=mask)
104
+ else:
105
+ # We don't want to load x[0, 2, 4, ...] and x[1, 3, 5, ...] separately since both are slow.
106
+ # Instead, we load x0 = x[0, 1, 2, 3, ...] and x1 = x[1, 0, 3, 2, ...].
107
+ # Loading x0 will be fast but x1 will be slow.
108
+ # Then we load cos = cos[0, 0, 1, 1, ...] and sin = sin[0, 0, 1, 1, ...].
109
+ # Then we do the calculation and use tl.where to pick out the right outputs for the even
110
+ # and for the odd indices.
111
+ o_d = tl.arange(0, BD)
112
+ o_d_swap = o_d + ((o_d + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
113
+ o_d_repeat = tl.arange(0, BD) // 2
114
+ p_x0 = x + o_t[:, None] * H*D + o_d[None, :]
115
+ p_x1 = x + o_t[:, None] * H*D + o_d_swap[None, :]
116
+ p_cos = cos + (o_cs[:, None] * R + o_d_repeat[None, :])
117
+ p_sin = sin + (o_cs[:, None] * R + o_d_repeat[None, :])
118
+ mask = (o_cs[:, None] >= 0) & (o_cs[:, None] < TR) & (o_d_repeat[None, :] < R)
119
+
120
+ b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32)
121
+ b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32)
122
+ b_x0 = tl.load(p_x0, mask=mask, other=0.0).to(tl.float32)
123
+ b_x1 = tl.load(p_x1, mask=mask, other=0.0).to(tl.float32)
124
+ if CONJUGATE:
125
+ b_sin = -b_sin
126
+ b_o0 = b_x0 * b_cos
127
+ b_o1 = b_x1 * b_sin
128
+ b_y = tl.where(o_d[None, :] % 2 == 0, b_o0 - b_o1, b_o0 + b_o1)
129
+ p_y = y + (o_t[:, None] * H*D + o_d[None, :])
130
+ tl.store(p_y, b_y, mask=mask)
131
+
132
+
133
+ def rotary_embedding_fwdbwd(
134
+ x: torch.Tensor,
135
+ cos: torch.Tensor,
136
+ sin: torch.Tensor,
137
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
138
+ cu_seqlens: Optional[torch.Tensor] = None,
139
+ max_seqlen: Optional[int] = None,
140
+ interleaved: bool = False,
141
+ inplace: bool = False,
142
+ conjugate: bool = False
143
+ ) -> torch.Tensor:
144
+ """
145
+ Args:
146
+ x: [B, T, H, D].
147
+ cos: [TR, R / 2]
148
+ sin: [TR, R / 2]
149
+ seqlen_offsets: integer or integer tensor of size (N,)
150
+ cu_seqlens: (N + 1,) or None
151
+ max_seqlen: int
152
+
153
+ Returns:
154
+ y: [B, T, H, D]
155
+ """
156
+ is_varlen = cu_seqlens is not None
157
+
158
+ B, T, H, D = x.shape
159
+ if not is_varlen:
160
+ N = B
161
+ else:
162
+ assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
163
+ N, T = cu_seqlens.shape[0] - 1, max_seqlen
164
+ TR, R = cos.shape
165
+ assert sin.shape == cos.shape
166
+ R2 = R * 2
167
+
168
+ assert D <= 256, "Only support D <= 256"
169
+ assert TR >= T, "TR must be >= T"
170
+
171
+ assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
172
+ assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
173
+
174
+ if isinstance(seqlen_offsets, torch.Tensor):
175
+ assert seqlen_offsets.shape == (N,)
176
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
177
+ else:
178
+ assert seqlen_offsets + T <= TR
179
+
180
+ y = torch.empty_like(x) if not inplace else x
181
+ if R2 < D and not inplace:
182
+ y[..., R2:].copy_(x[..., R2:])
183
+
184
+ BD = triton.next_power_of_2(R2)
185
+ BT = min(128, triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index))))
186
+
187
+ def grid(meta): return (triton.cdiv(T, meta['BT']), N, H) # noqa
188
+ rotary_embedding_kernel[grid](
189
+ x,
190
+ cos,
191
+ sin,
192
+ y,
193
+ cu_seqlens,
194
+ seqlen_offsets,
195
+ B=B,
196
+ T=T,
197
+ H=H,
198
+ D=D,
199
+ R=R,
200
+ TR=TR,
201
+ BT=BT,
202
+ BD=BD,
203
+ IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor),
204
+ IS_VARLEN=is_varlen,
205
+ INTERLEAVED=interleaved,
206
+ CONJUGATE=conjugate
207
+ )
208
+ return y
209
+
210
+
211
+ class RotaryEmbeddingFunction(torch.autograd.Function):
212
+
213
+ @staticmethod
214
+ @input_guard
215
+ def forward(
216
+ ctx,
217
+ x,
218
+ cos,
219
+ sin,
220
+ interleaved=False,
221
+ inplace=False,
222
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
223
+ cu_seqlens: Optional[torch.Tensor] = None,
224
+ max_seqlen: Optional[int] = None,
225
+ ):
226
+ y = rotary_embedding_fwdbwd(
227
+ x,
228
+ cos,
229
+ sin,
230
+ seqlen_offsets=seqlen_offsets,
231
+ cu_seqlens=cu_seqlens,
232
+ max_seqlen=max_seqlen,
233
+ interleaved=interleaved,
234
+ inplace=inplace,
235
+ )
236
+ if isinstance(seqlen_offsets, int):
237
+ # Can't save int with save_for_backward
238
+ ctx.save_for_backward(cos, sin, cu_seqlens)
239
+ ctx.seqlen_offsets = seqlen_offsets
240
+ else:
241
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
242
+ ctx.seqlen_offsets = None
243
+ ctx.interleaved = interleaved
244
+ ctx.inplace = inplace
245
+ ctx.max_seqlen = max_seqlen
246
+ return y if not inplace else x
247
+
248
+ @staticmethod
249
+ @input_guard
250
+ def backward(ctx, do):
251
+ seqlen_offsets = ctx.seqlen_offsets
252
+ if seqlen_offsets is None:
253
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
254
+ else:
255
+ cos, sin, cu_seqlens = ctx.saved_tensors
256
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
257
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
258
+ if not ctx.interleaved and not ctx.inplace:
259
+ do = do.clone()
260
+ dx = rotary_embedding_fwdbwd(
261
+ do,
262
+ cos,
263
+ sin,
264
+ seqlen_offsets=seqlen_offsets,
265
+ cu_seqlens=cu_seqlens,
266
+ max_seqlen=ctx.max_seqlen,
267
+ interleaved=ctx.interleaved,
268
+ inplace=ctx.inplace,
269
+ conjugate=True,
270
+ )
271
+ return dx, None, None, None, None, None, None, None
272
+
273
+
274
+ def rotary_embedding(
275
+ x,
276
+ cos,
277
+ sin,
278
+ interleaved=False,
279
+ inplace=False,
280
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
281
+ cu_seqlens: Optional[torch.Tensor] = None,
282
+ max_seqlen: Optional[int] = None,
283
+ ):
284
+ """
285
+ Args:
286
+ x: [B, T, H, D]
287
+ cos, sin: [TR, R//2]
288
+ interleaved:
289
+ If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
290
+ inplace:
291
+ If True, apply rotary embedding in-place.
292
+ seqlen_offsets: [N,] or int.
293
+ Each sequence in x is shifted by this amount.
294
+ Most commonly used in inference when we have KV cache.
295
+ cu_seqlens: [N + 1,] or None
296
+ max_seqlen: int
297
+
298
+ Returns:
299
+ out: [B, T, H, D]
300
+ """
301
+ return RotaryEmbeddingFunction.apply(
302
+ x,
303
+ cos,
304
+ sin,
305
+ interleaved,
306
+ inplace,
307
+ seqlen_offsets,
308
+ cu_seqlens,
309
+ max_seqlen
310
+ )
311
+
312
+
313
+ class RotaryEmbedding(nn.Module):
314
+ """
315
+ The rotary position embeddings from RoFormer_ (Su et al.).
316
+ A crucial insight from the method is that the query and keys are
317
+ transformed by rotation matrices which depend on the relative positions.
318
+
319
+ Other implementations are available in the Rotary Transformer repo_ and in
320
+ GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.
321
+
322
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
323
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
324
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
325
+
326
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
327
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
328
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
329
+ """
330
+
331
+ def __init__(
332
+ self,
333
+ dim: int,
334
+ base: float = 10000.0,
335
+ scale_base: Optional[float] = None,
336
+ interleaved: bool = False,
337
+ pos_idx_in_fp32: bool = True,
338
+ device: Optional[torch.device] = None,
339
+ ):
340
+ """
341
+ interleaved:
342
+ If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
343
+ pos_idx_in_fp32:
344
+ If True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision.
345
+ This option was added because previously (before 2023-07-02), when we construct
346
+ the position indices, we use the dtype of self.inv_freq.
347
+ In most cases this would be fp32, but if the model is trained in pure bf16 (not mixed precision), then
348
+ self.inv_freq would be bf16, and the position indices are also in bf16.
349
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
350
+ embeddings for some positions will coincide.
351
+ To maintain compatibility with models previously trained in pure bf16, we add this option.
352
+ """
353
+ super().__init__()
354
+
355
+ self.dim = dim
356
+ self.base = float(base)
357
+ self.scale_base = scale_base
358
+ self.interleaved = interleaved
359
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
360
+ self.device = device
361
+
362
+ # Generate and save the inverse frequency buffer (non trainable)
363
+ self.register_buffer("inv_freq", torch.empty(-(dim // -2), dtype=torch.float32, device=device), persistent=False)
364
+
365
+ scale = None
366
+ if scale_base is not None:
367
+ scale = torch.empty(-(dim // -2), dtype=torch.float32, device=device)
368
+ self.register_buffer("scale", scale, persistent=False)
369
+
370
+ self._seq_len_cached = 0
371
+ self._cos_cached = None
372
+ self._sin_cached = None
373
+ self._cos_k_cached = None
374
+ self._sin_k_cached = None
375
+
376
+ self.reset_parameters()
377
+
378
+ def reset_parameters(self):
379
+ with torch.no_grad():
380
+ self.inv_freq.copy_(self._compute_inv_freq(device=self.inv_freq.device))
381
+ if self.scale_base is not None:
382
+ self.scale.copy_(self._compute_scale(device=self.scale.device))
383
+
384
+ def __repr__(self):
385
+ s = f"{self.__class__.__name__}("
386
+ s += f"dim={self.dim}, "
387
+ s += f"base={self.base}, "
388
+ s += f"interleaved={self.interleaved}, "
389
+ if self.scale_base is not None:
390
+ s += f"scale_base={self.scale_base}, "
391
+ s += f"pos_idx_in_fp32={self.pos_idx_in_fp32})"
392
+ return s
393
+
394
+ def _compute_inv_freq(self, device=None):
395
+ return 1.0 / (
396
+ self.base
397
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
398
+ )
399
+
400
+ def _compute_scale(self, device=None):
401
+ return (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) + 0.4 * self.dim) / (1.4 * self.dim)
402
+
403
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
404
+ # Reset the tables if the sequence length has changed,
405
+ # if we're on a new device (possibly due to tracing for instance),
406
+ # or if we're switching from inference mode to training
407
+ if (
408
+ seqlen > self._seq_len_cached
409
+ or self._cos_cached is None
410
+ or self._cos_cached.device != device
411
+ or self._cos_cached.dtype != dtype
412
+ or (self.training and self._cos_cached.is_inference())
413
+ ):
414
+ self._seq_len_cached = seqlen
415
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
416
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
417
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
418
+ if self.pos_idx_in_fp32:
419
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
420
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
421
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
422
+ # cos & sin output to change significantly.
423
+ # We want to recompute self.inv_freq if it was not loaded in fp32
424
+ if self.inv_freq.dtype != torch.float32:
425
+ inv_freq = self._compute_inv_freq(device=device)
426
+ else:
427
+ inv_freq = self.inv_freq
428
+ else:
429
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
430
+ inv_freq = self.inv_freq
431
+ # Don't do einsum, it converts fp32 to fp16 under AMP
432
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
433
+ freqs = torch.outer(t, inv_freq)
434
+ if self.scale is None:
435
+ self._cos_cached = torch.cos(freqs).to(dtype)
436
+ self._sin_cached = torch.sin(freqs).to(dtype)
437
+ else:
438
+ power = (
439
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
440
+ - seqlen // 2
441
+ ) / self.scale_base
442
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
443
+ # We want the multiplication by scale to happen in fp32
444
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
445
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
446
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
447
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
448
+
449
+ def forward(
450
+ self,
451
+ q: torch.Tensor,
452
+ k: torch.Tensor,
453
+ seqlen_offset: Union[int, torch.Tensor] = 0,
454
+ cu_seqlens: Optional[torch.Tensor] = None,
455
+ max_seqlen: Optional[int] = None,
456
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
457
+ """
458
+ q: [B, T, H, D]
459
+ k: [B, T, H, D]
460
+ seqlen_offset:
461
+ (N,) or int. Each sequence in x is shifted by this amount.
462
+ Most commonly used in inference when we have KV cache.
463
+ If it's a tensor of shape (N,), then to update the cos / sin cache, one
464
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
465
+ cu_seqlens: (N + 1,) or None
466
+ max_seqlen: int
467
+ """
468
+ if max_seqlen is not None:
469
+ self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
470
+ elif isinstance(seqlen_offset, int):
471
+ self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
472
+ if self.scale is None:
473
+ q = rotary_embedding(
474
+ q,
475
+ self._cos_cached,
476
+ self._sin_cached,
477
+ interleaved=self.interleaved,
478
+ seqlen_offsets=seqlen_offset,
479
+ cu_seqlens=cu_seqlens,
480
+ max_seqlen=max_seqlen
481
+ )
482
+ k = rotary_embedding(
483
+ k,
484
+ self._cos_cached,
485
+ self._sin_cached,
486
+ interleaved=self.interleaved,
487
+ seqlen_offsets=seqlen_offset,
488
+ cu_seqlens=cu_seqlens,
489
+ max_seqlen=max_seqlen
490
+ )
491
+
492
+ else:
493
+ q = rotary_embedding(
494
+ q,
495
+ self._cos_cached,
496
+ self._sin_cached,
497
+ interleaved=self.interleaved,
498
+ seqlen_offsets=seqlen_offset,
499
+ cu_seqlens=cu_seqlens,
500
+ max_seqlen=max_seqlen
501
+ )
502
+ k = rotary_embedding(
503
+ k,
504
+ self._cos_k_cached,
505
+ self._sin_k_cached,
506
+ interleaved=self.interleaved,
507
+ seqlen_offsets=seqlen_offset,
508
+ cu_seqlens=cu_seqlens,
509
+ max_seqlen=max_seqlen
510
+ )
511
+
512
+ return q, k
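The Triton kernel and the RotaryEmbedding module above can be sanity-checked against the pure-PyTorch rotary_embedding_ref defined at the top of the file, which needs no GPU. A small sketch (shapes and the base 10000 are illustrative):

# sketch: build cos/sin the same way _update_cos_sin_cache does and apply the reference path
import torch
from fla.modules.rotary import rotary_embedding_ref

B, T, H, D = 2, 16, 4, 64
x = torch.randn(B, T, H, D)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
freqs = torch.outer(torch.arange(T, dtype=torch.float32), inv_freq)  # [T, D // 2]
y = rotary_embedding_ref(x, freqs.cos(), freqs.sin())                # [B, T, H, D]
assert y.shape == x.shape
# on CUDA, rotary_embedding(...) / RotaryEmbedding(dim=D) should match this reference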
fla/ops/__init__.py ADDED
@@ -0,0 +1,46 @@
+ # -*- coding: utf-8 -*-
+
+ from .abc import chunk_abc
+ from .attn import parallel_attn, parallel_rectified_attn, parallel_softpick_attn, naive_attn, naive_rectified_attn, naive_softpick_attn
+ from .based import fused_chunk_based, parallel_based
+ from .delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule
+ from .forgetting_attn import parallel_forgetting_attn
+ from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
+ from .generalized_delta_rule import (
+     chunk_dplr_delta_rule,
+     chunk_iplr_delta_rule,
+     fused_recurrent_dplr_delta_rule,
+     fused_recurrent_iplr_delta_rule
+ )
+ from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
+ from .gsa import chunk_gsa, fused_recurrent_gsa
+ from .hgrn import fused_recurrent_hgrn
+ from .lightning_attn import chunk_lightning_attn, fused_recurrent_lightning_attn
+ from .linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
+ from .nsa import parallel_nsa
+ from .retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
+ from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
+ from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
+ from .simple_gla import chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla
+
+ __all__ = [
+     'chunk_abc',
+     'parallel_attn', 'parallel_rectified_attn', 'parallel_softpick_attn',
+     'naive_attn', 'naive_rectified_attn', 'naive_softpick_attn',
+     'fused_chunk_based', 'parallel_based',
+     'chunk_delta_rule', 'fused_chunk_delta_rule', 'fused_recurrent_delta_rule',
+     'parallel_forgetting_attn',
+     'chunk_gated_delta_rule', 'fused_recurrent_gated_delta_rule',
+     'chunk_dplr_delta_rule', 'chunk_iplr_delta_rule',
+     'fused_recurrent_dplr_delta_rule', 'fused_recurrent_iplr_delta_rule',
+     'chunk_gla', 'fused_chunk_gla', 'fused_recurrent_gla',
+     'chunk_gsa', 'fused_recurrent_gsa',
+     'fused_recurrent_hgrn',
+     'chunk_lightning_attn', 'fused_recurrent_lightning_attn',
+     'chunk_linear_attn', 'fused_chunk_linear_attn', 'fused_recurrent_linear_attn',
+     'parallel_nsa',
+     'chunk_retention', 'fused_chunk_retention', 'fused_recurrent_retention', 'parallel_retention',
+     'chunk_rwkv6', 'fused_recurrent_rwkv6',
+     'chunk_rwkv7', 'fused_recurrent_rwkv7',
+     'chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla',
+ ]
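This __init__ flattens every kernel family into a single namespace, so the naive, parallel, rectified, and softpick attention variants (and the linear-attention kernels) are interchangeable at the import site; for example (illustrative):

from fla.ops import naive_attn, parallel_attn, parallel_rectified_attn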
fla/ops/abc/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # -*- coding: utf-8 -*-
+
+ from .chunk import chunk_abc
+
+ __all__ = [
+     'chunk_abc'
+ ]
fla/ops/attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (540 Bytes).
fla/ops/attn/__pycache__/naive_rectified.cpython-312.pyc ADDED
Binary file (2.24 kB).
fla/ops/attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (33.1 kB).
fla/ops/attn/naive.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ from typing import Optional, Tuple
+ from einops import rearrange
+
+ def naive_attn(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     scale: Optional[float] = None,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+     head_first: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     head_dim = q.shape[-1]
+     if scale is None:
+         scale = 1.0 / (head_dim ** 0.5)
+     if not head_first:
+         q, k, v = map(lambda x: rearrange(x, 'b t h d -> b h t d'), (q, k, v))
+     q_len = q.shape[-2]
+     k_len = k.shape[-2]
+     mask = torch.tril(torch.ones(k_len, k_len, device=q.device))
+     wei = torch.matmul(q, k.transpose(2, 3))  # shape: (batch_size, num_heads, q_len, k_len)
+     wei = wei * scale
+     wei = wei.masked_fill(mask[k_len-q_len:k_len, :k_len] == 0, float('-inf'))
+     wei = torch.softmax(wei.float(), dim=-1).to(q.dtype)
+     o = torch.matmul(wei, v)  # shape: (batch_size, num_heads, q_len, head_dim)
+     if not head_first:
+         o = rearrange(o, 'b h t d -> b t h d')
+     return o, wei
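Since naive_attn is just masked softmax attention in plain PyTorch, it should agree with torch's built-in scaled dot-product attention under a causal mask; a quick check, assuming fp32 inputs and a loosely chosen tolerance:

import torch
import torch.nn.functional as F
from fla.ops.attn.naive import naive_attn

B, T, H, D = 2, 32, 4, 64
q, k, v = (torch.randn(B, T, H, D) for _ in range(3))
o, _ = naive_attn(q, k, v)  # returns (output, attention weights)
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)
assert torch.allclose(o, ref, atol=1e-5)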
fla/ops/attn/parallel.py ADDED
@@ -0,0 +1,629 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
23
+ for num_stages in [2, 3, 4, 5]
24
+ ],
25
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
26
+ )
27
+ @triton.jit
28
+ def parallel_attn_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ o,
33
+ lse,
34
+ scale,
35
+ offsets,
36
+ indices,
37
+ T,
38
+ B: tl.constexpr,
39
+ H: tl.constexpr,
40
+ HQ: tl.constexpr,
41
+ G: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BS: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ BV: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
52
+ i_h = i_hq // G
53
+
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ i_n = i_b
60
+ bos, eos = i_n * T, i_n * T + T
61
+
62
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
63
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
64
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
65
+
66
+ # the Q block is kept in the shared memory throughout the whole kernel
67
+ # [BT, BK]
68
+ b_q = tl.load(p_q, boundary_check=(0, 1))
69
+ b_q = (b_q * scale).to(b_q.dtype)
70
+ # [BT, BV]
71
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
72
+
73
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
74
+ b_acc = tl.zeros([BT], dtype=tl.float32)
75
+ for i_s in range(0, i_t * BT, BS):
76
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
77
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
78
+ # [BK, BS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BT, BS]
83
+ b_s = tl.dot(b_q, b_k)
84
+
85
+ # [BT, BS]
86
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
87
+ b_r = exp(b_mp - b_m)
88
+ # [BT, BS]
89
+ b_p = exp(b_s - b_m[:, None])
90
+ # [BT]
91
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
92
+ # [BT, BV]
93
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
94
+
95
+ b_mp = b_m
96
+
97
+ # [BT]
98
+ o_q = i_t * BT + tl.arange(0, BT)
99
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
100
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
101
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
102
+
103
+ # [BS]
104
+ o_k = i_s + tl.arange(0, BS)
105
+ # [BK, BS]
106
+ b_k = tl.load(p_k, boundary_check=(0, 1))
107
+ # [BS, BV]
108
+ b_v = tl.load(p_v, boundary_check=(0, 1))
109
+ # [BT, BS]
110
+ b_s = tl.dot(b_q, b_k)
111
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
112
+
113
+ # [BT]
114
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
115
+ b_r = exp(b_mp - b_m)
116
+ # [BT, BS]
117
+ b_p = exp(b_s - b_m[:, None])
118
+ # [BT]
119
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
120
+ # [BT, BV]
121
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
122
+
123
+ b_mp = b_m
124
+ b_o = b_o / b_acc[:, None]
125
+ b_m += log(b_acc)
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
128
+
129
+
130
+ @triton.jit
131
+ def parallel_attn_bwd_kernel_preprocess(
132
+ o,
133
+ do,
134
+ delta,
135
+ B: tl.constexpr,
136
+ V: tl.constexpr
137
+ ):
138
+ i_n = tl.program_id(0)
139
+ o_d = tl.arange(0, B)
140
+ m_d = o_d < V
141
+
142
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
143
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
144
+ b_delta = tl.sum(b_o * b_do)
145
+
146
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
147
+
148
+
149
+ @triton.heuristics({
150
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
151
+ })
152
+ @triton.autotune(
153
+ configs=[
154
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
155
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
156
+ for num_stages in [2, 3, 4, 5]
157
+ ],
158
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
159
+ )
160
+ @triton.jit(do_not_specialize=['T'])
161
+ def parallel_attn_bwd_kernel_dq(
162
+ q,
163
+ k,
164
+ v,
165
+ lse,
166
+ delta,
167
+ do,
168
+ dq,
169
+ scale,
170
+ offsets,
171
+ indices,
172
+ T,
173
+ B: tl.constexpr,
174
+ H: tl.constexpr,
175
+ HQ: tl.constexpr,
176
+ G: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BS: tl.constexpr,
181
+ BK: tl.constexpr,
182
+ BV: tl.constexpr,
183
+ USE_OFFSETS: tl.constexpr
184
+ ):
185
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
186
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
187
+ i_h = i_hq // G
188
+
189
+ if USE_OFFSETS:
190
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
191
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
192
+ T = eos - bos
193
+ else:
194
+ i_n = i_b
195
+ bos, eos = i_n * T, i_n * T + T
196
+
197
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
198
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
199
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
200
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
201
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
202
+
203
+ # [BT, BK]
204
+ b_q = tl.load(p_q, boundary_check=(0, 1))
205
+ b_q = (b_q * scale).to(b_q.dtype)
206
+ # [BT, BV]
207
+ b_do = tl.load(p_do, boundary_check=(0, 1))
208
+ # [BT]
209
+ b_lse = tl.load(p_lse, boundary_check=(0,))
210
+ b_delta = tl.load(p_delta, boundary_check=(0,))
211
+
212
+ # [BT, BK]
213
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
214
+ for i_s in range(0, i_t * BT, BS):
215
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
216
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
217
+ # [BK, BS]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ # [BV, BS]
220
+ b_v = tl.load(p_v, boundary_check=(0, 1))
221
+
222
+ # [BT, BS]
223
+ b_s = tl.dot(b_q, b_k)
224
+ b_p = exp(b_s - b_lse[:, None])
225
+
226
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
227
+ b_dp = tl.dot(b_do, b_v)
228
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
229
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
230
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
231
+
232
+ # [BT]
233
+ o_q = i_t * BT + tl.arange(0, BT)
234
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
235
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
236
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
237
+ # [BS]
238
+ o_k = i_s + tl.arange(0, BS)
239
+ # [BK, BS]
240
+ b_k = tl.load(p_k, boundary_check=(0, 1))
241
+ # [BV, BS]
242
+ b_v = tl.load(p_v, boundary_check=(0, 1))
243
+
244
+ # [BT, BS]
245
+ b_s = tl.dot(b_q, b_k)
246
+ b_p = exp(b_s - b_lse[:, None])
247
+ b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)
248
+
249
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
250
+ b_dp = tl.dot(b_do, b_v)
251
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
252
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
253
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
254
+
255
+ b_dq *= scale
256
+
257
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
258
+
259
+
260
+ @triton.heuristics({
261
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
262
+ })
263
+ @triton.autotune(
264
+ configs=[
265
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
266
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
267
+ for num_stages in [2, 3, 4, 5]
268
+ ],
269
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
270
+ )
271
+ @triton.jit(do_not_specialize=['T'])
272
+ def parallel_attn_bwd_kernel_dkv(
273
+ q,
274
+ k,
275
+ v,
276
+ lse,
277
+ delta,
278
+ do,
279
+ dk,
280
+ dv,
281
+ offsets,
282
+ indices,
283
+ scale,
284
+ T,
285
+ B: tl.constexpr,
286
+ H: tl.constexpr,
287
+ HQ: tl.constexpr,
288
+ G: tl.constexpr,
289
+ K: tl.constexpr,
290
+ V: tl.constexpr,
291
+ BT: tl.constexpr,
292
+ BS: tl.constexpr,
293
+ BK: tl.constexpr,
294
+ BV: tl.constexpr,
295
+ USE_OFFSETS: tl.constexpr
296
+ ):
297
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
298
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
299
+ i_h = i_hq // G
300
+
301
+ if USE_OFFSETS:
302
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
303
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
304
+ T = eos - bos
305
+ else:
306
+ i_n = i_b
307
+ bos, eos = i_n * T, i_n * T + T
308
+
309
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
310
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
311
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
312
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
313
+
314
+ # [BT, BK]
315
+ b_k = tl.load(p_k, boundary_check=(0, 1))
316
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
317
+ # [BT, BV]
318
+ b_v = tl.load(p_v, boundary_check=(0, 1))
319
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
320
+
321
+ o_k = i_t * BT + tl.arange(0, BT)
322
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
323
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
324
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
325
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
326
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
327
+
328
+ # [BS]
329
+ o_q = i_s + tl.arange(0, BS)
330
+ # [BS, BK]
331
+ b_q = tl.load(p_q, boundary_check=(0, 1))
332
+ b_q = (b_q * scale).to(b_q.dtype)
333
+ # [BS, BV]
334
+ b_do = tl.load(p_do, boundary_check=(0, 1))
335
+ # [BS]
336
+ b_lse = tl.load(p_lse, boundary_check=(0,))
337
+ b_delta = tl.load(p_delta, boundary_check=(0,))
338
+ # [BT, BS]
339
+ b_s = tl.dot(b_k, tl.trans(b_q))
340
+ b_p = exp(b_s - b_lse[None, :])
341
+ b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
342
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
343
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
344
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
345
+ b_dp = tl.dot(b_v, tl.trans(b_do))
346
+ # [BT, BS]
347
+ b_ds = b_p * (b_dp - b_delta[None, :])
348
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
349
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
350
+
351
+ for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
352
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
353
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
354
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
355
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
356
+
357
+ # [BS]
358
+ o_q = i_s + tl.arange(0, BS)
359
+ # [BS, BK]
360
+ b_q = tl.load(p_q, boundary_check=(0, 1))
361
+ b_q = (b_q * scale).to(b_q.dtype)
362
+ # [BS, BV]
363
+ b_do = tl.load(p_do, boundary_check=(0, 1))
364
+ # [BS]
365
+ b_lse = tl.load(p_lse, boundary_check=(0,))
366
+ b_delta = tl.load(p_delta, boundary_check=(0,))
367
+ # [BT, BS]
368
+ b_s = tl.dot(b_k, tl.trans(b_q))
369
+ b_p = exp(b_s - b_lse[None, :])
370
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
371
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
372
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
373
+ b_dp = tl.dot(b_v, tl.trans(b_do))
374
+ # [BT, BS]
375
+ b_ds = b_p * (b_dp - b_delta[None, :])
376
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
377
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
378
+
379
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
380
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
381
+
382
+
383
+ def parallel_attn_fwd(
384
+ q: torch.Tensor,
385
+ k: torch.Tensor,
386
+ v: torch.Tensor,
387
+ scale: float,
388
+ chunk_size: int = 128,
389
+ offsets: Optional[torch.LongTensor] = None,
390
+ indices: Optional[torch.LongTensor] = None,
391
+ ):
392
+ B, T, H, K, V = *k.shape, v.shape[-1]
393
+ HQ = q.shape[2]
394
+ G = HQ // H
395
+ BT = chunk_size
396
+ if check_shared_mem('hopper', q.device.index):
397
+ BS = min(64, max(16, triton.next_power_of_2(T)))
398
+ BK = min(256, max(16, triton.next_power_of_2(K)))
399
+ BV = min(256, max(16, triton.next_power_of_2(V)))
400
+ elif check_shared_mem('ampere', q.device.index):
401
+ BS = min(32, max(16, triton.next_power_of_2(T)))
402
+ BK = min(256, max(16, triton.next_power_of_2(K)))
403
+ BV = min(128, max(16, triton.next_power_of_2(V)))
404
+ else:
405
+ BS = min(32, max(16, triton.next_power_of_2(T)))
406
+ BK = min(256, max(16, triton.next_power_of_2(K)))
407
+ BV = min(64, max(16, triton.next_power_of_2(V)))
408
+ NK = triton.cdiv(K, BK)
409
+ NV = triton.cdiv(V, BV)
410
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
411
+ assert NK == 1, "The key dimension cannot be larger than 256"
412
+
413
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
414
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
415
+
416
+ grid = (NV, NT, B * HQ)
417
+ parallel_attn_fwd_kernel[grid](
418
+ q=q,
419
+ k=k,
420
+ v=v,
421
+ o=o,
422
+ lse=lse,
423
+ scale=scale,
424
+ offsets=offsets,
425
+ indices=indices,
426
+ B=B,
427
+ T=T,
428
+ H=H,
429
+ HQ=HQ,
430
+ G=G,
431
+ K=K,
432
+ V=V,
433
+ BT=BT,
434
+ BS=BS,
435
+ BK=BK,
436
+ BV=BV,
437
+ )
438
+ return o, lse
439
+
440
+
441
+ def parallel_attn_bwd_preprocess(
442
+ o: torch.Tensor,
443
+ do: torch.Tensor
444
+ ):
445
+ V = o.shape[-1]
446
+ delta = torch.empty_like(o[..., 0], dtype=torch.float32)
447
+ parallel_attn_bwd_kernel_preprocess[(delta.numel(),)](
448
+ o=o,
449
+ do=do,
450
+ delta=delta,
451
+ B=triton.next_power_of_2(V),
452
+ V=V,
453
+ )
454
+ return delta
455
+
456
+
457
+ def parallel_attn_bwd(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ o: torch.Tensor,
462
+ lse: torch.Tensor,
463
+ do: torch.Tensor,
464
+ scale: float = None,
465
+ chunk_size: int = 128,
466
+ offsets: Optional[torch.LongTensor] = None,
467
+ indices: Optional[torch.LongTensor] = None,
468
+ ):
469
+ B, T, H, K, V = *k.shape, v.shape[-1]
470
+ HQ = q.shape[2]
471
+ G = HQ // H
472
+ BT = chunk_size
473
+ BS = max(16, triton.next_power_of_2(T))
474
+ BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
475
+ BK = max(16, triton.next_power_of_2(K))
476
+ BV = max(16, triton.next_power_of_2(V))
477
+ NV = triton.cdiv(V, BV)
478
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
479
+
480
+ delta = parallel_attn_bwd_preprocess(o, do)
481
+
482
+ dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
483
+ dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
484
+ dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
485
+ grid = (NV, NT, B * HQ)
486
+ parallel_attn_bwd_kernel_dq[grid](
487
+ q=q,
488
+ k=k,
489
+ v=v,
490
+ lse=lse,
491
+ delta=delta,
492
+ do=do,
493
+ dq=dq,
494
+ offsets=offsets,
495
+ indices=indices,
496
+ scale=scale,
497
+ T=T,
498
+ B=B,
499
+ H=H,
500
+ HQ=HQ,
501
+ G=G,
502
+ K=K,
503
+ V=V,
504
+ BT=BT,
505
+ BS=BS,
506
+ BK=BK,
507
+ BV=BV
508
+ )
509
+ parallel_attn_bwd_kernel_dkv[grid](
510
+ q=q,
511
+ k=k,
512
+ v=v,
513
+ lse=lse,
514
+ delta=delta,
515
+ do=do,
516
+ dk=dk,
517
+ dv=dv,
518
+ offsets=offsets,
519
+ indices=indices,
520
+ scale=scale,
521
+ T=T,
522
+ B=B,
523
+ H=H,
524
+ HQ=HQ,
525
+ G=G,
526
+ K=K,
527
+ V=V,
528
+ BT=BT,
529
+ BS=BS,
530
+ BK=BK,
531
+ BV=BV
532
+ )
533
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
534
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
535
+ return dq, dk, dv
536
+
537
+
538
+ @torch.compile
539
+ class ParallelAttentionFunction(torch.autograd.Function):
540
+
541
+ @staticmethod
542
+ @contiguous
543
+ @autocast_custom_fwd
544
+ def forward(ctx, q, k, v, scale, offsets):
545
+ ctx.dtype = q.dtype
546
+
547
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
548
+ # 2-d indices denoting the offsets of chunks in each sequence
549
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
550
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
551
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
552
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
553
+
554
+ o, lse = parallel_attn_fwd(
555
+ q=q,
556
+ k=k,
557
+ v=v,
558
+ scale=scale,
559
+ chunk_size=chunk_size,
560
+ offsets=offsets,
561
+ indices=indices
562
+ )
563
+ ctx.save_for_backward(q, k, v, o, lse)
564
+ ctx.chunk_size = chunk_size
565
+ ctx.offsets = offsets
566
+ ctx.indices = indices
567
+ ctx.scale = scale
568
+ return o.to(q.dtype)
569
+
570
+ @staticmethod
571
+ @contiguous
572
+ @autocast_custom_bwd
573
+ def backward(ctx, do):
574
+ q, k, v, o, lse = ctx.saved_tensors
575
+ dq, dk, dv = parallel_attn_bwd(
576
+ q=q,
577
+ k=k,
578
+ v=v,
579
+ o=o,
580
+ lse=lse,
581
+ do=do,
582
+ scale=ctx.scale,
583
+ chunk_size=ctx.chunk_size,
584
+ offsets=ctx.offsets,
585
+ indices=ctx.indices
586
+ )
587
+ return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
588
+
589
+
590
+ def parallel_attn(
591
+ q: torch.Tensor,
592
+ k: torch.Tensor,
593
+ v: torch.Tensor,
594
+ scale: Optional[float] = None,
595
+ cu_seqlens: Optional[torch.LongTensor] = None,
596
+ head_first: bool = False
597
+ ) -> torch.Tensor:
598
+ r"""
599
+ Args:
600
+ q (torch.Tensor):
601
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
602
+ k (torch.Tensor):
603
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
604
+ GQA will be applied if HQ is divisible by H.
605
+ v (torch.Tensor):
606
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
607
+ scale (Optional[float]):
608
+ Scale factor for attention scores.
609
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
610
+ cu_seqlens (torch.LongTensor):
611
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
612
+ consistent with the FlashAttention API.
613
+ head_first (Optional[bool]):
614
+ Whether the inputs are in the head-first format. Default: `False`.
615
+
616
+ Returns:
617
+ o (torch.Tensor):
618
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
619
+ """
620
+ if scale is None:
621
+ scale = k.shape[-1] ** -0.5
622
+ if cu_seqlens is not None:
623
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
624
+ if head_first:
625
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
626
+ o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens)
627
+ if head_first:
628
+ o = rearrange(o, 'b t h d -> b h t d')
629
+ return o
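parallel_attn supports GQA (more query heads than KV heads) and packed variable-length batches via cu_seqlens; internally the autograd function turns those offsets into per-chunk indices as described in the comment in forward. A usage sketch, assuming a CUDA device with Triton available (shapes, dtypes, and sequence boundaries are illustrative):

import torch
from fla.ops import parallel_attn

B, T, H, HQ, D = 1, 1024, 2, 8, 64          # 8 query heads grouped over 2 KV heads
q = torch.randn(B, T, HQ, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
# two packed sequences of lengths 600 and 424; the batch dim must be 1 in this mode
cu_seqlens = torch.tensor([0, 600, 1024], dtype=torch.long, device='cuda')
o = parallel_attn(q, k, v, cu_seqlens=cu_seqlens)   # [1, 1024, 8, 64]
o.float().sum().backward()                          # populates q.grad / k.grad / v.grad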
fla/ops/attn/parallel_rectified.py ADDED
@@ -0,0 +1,643 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
23
+ for num_stages in [2, 3, 4, 5]
24
+ ],
25
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
26
+ )
27
+ @triton.jit
28
+ def parallel_rect_attn_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ o,
33
+ lse,
34
+ scale,
35
+ offsets,
36
+ indices,
37
+ T,
38
+ B: tl.constexpr,
39
+ H: tl.constexpr,
40
+ HQ: tl.constexpr,
41
+ G: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BS: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ BV: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
52
+ i_h = i_hq // G
53
+
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ i_n = i_b
60
+ bos, eos = i_n * T, i_n * T + T
61
+
62
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
63
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
64
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
65
+
66
+ # the Q block is kept in the shared memory throughout the whole kernel
67
+ # [BT, BK]
68
+ b_q = tl.load(p_q, boundary_check=(0, 1))
69
+ b_q = (b_q * scale).to(b_q.dtype)
70
+ # [BT, BV]
71
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
72
+
73
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
74
+ b_acc = tl.zeros([BT], dtype=tl.float32)
75
+ for i_s in range(0, i_t * BT, BS):
76
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
77
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
78
+ # [BK, BS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BT, BS]
83
+ b_s = tl.dot(b_q, b_k)
84
+
85
+ # [BT, BS]
86
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
87
+ b_r = exp(b_mp - b_m)
88
+ # [BT, BS]
89
+ # b_p = exp(b_s - b_m[:, None])
90
+ # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_m[:, None]))
91
+ b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_m[:, None]))  # rectified: drop keys with negative scores instead of -inf masking
92
+ # [BT]
93
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
94
+ # [BT, BV]
95
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
96
+
97
+ b_mp = b_m
98
+
99
+ # [BT]
100
+ o_q = i_t * BT + tl.arange(0, BT)
101
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
102
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
103
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
104
+
105
+ # [BS]
106
+ o_k = i_s + tl.arange(0, BS)
107
+ # [BK, BS]
108
+ b_k = tl.load(p_k, boundary_check=(0, 1))
109
+ # [BS, BV]
110
+ b_v = tl.load(p_v, boundary_check=(0, 1))
111
+ # [BT, BS]
112
+ b_s = tl.dot(b_q, b_k)
113
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
114
+
115
+ # [BT]
116
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
117
+ b_r = exp(b_mp - b_m)
118
+ # [BT, BS]
119
+ # b_p = exp(b_s - b_m[:, None])
120
+ # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_m[:, None]))
121
+ b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_m[:, None]))
122
+ # [BT]
123
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
124
+ # [BT, BV]
125
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
126
+
127
+ b_mp = b_m
128
+ # b_o = b_o / b_acc[:, None]
129
+ b_o = tl.where(b_acc[:, None] == 0, 0.0, b_o / b_acc[:, None])
130
+ # b_m += tl.log(b_acc)
131
+ b_m = tl.where(b_acc == 0, 0.0, b_m + tl.log(b_acc))
132
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
133
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
134
+
135
+
136
+ @triton.jit
137
+ def parallel_rect_attn_bwd_kernel_preprocess(
138
+ o,
139
+ do,
140
+ delta,
141
+ B: tl.constexpr,
142
+ V: tl.constexpr
143
+ ):
144
+ i_n = tl.program_id(0)
145
+ o_d = tl.arange(0, B)
146
+ m_d = o_d < V
147
+
148
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
149
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
150
+ b_delta = tl.sum(b_o * b_do)
151
+
152
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
153
+
154
+
155
+ @triton.heuristics({
156
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
157
+ })
158
+ @triton.autotune(
159
+ configs=[
160
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
161
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
162
+ for num_stages in [2, 3, 4, 5]
163
+ ],
164
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
165
+ )
166
+ @triton.jit(do_not_specialize=['T'])
167
+ def parallel_rect_attn_bwd_kernel_dq(
168
+ q,
169
+ k,
170
+ v,
171
+ lse,
172
+ delta,
173
+ do,
174
+ dq,
175
+ scale,
176
+ offsets,
177
+ indices,
178
+ T,
179
+ B: tl.constexpr,
180
+ H: tl.constexpr,
181
+ HQ: tl.constexpr,
182
+ G: tl.constexpr,
183
+ K: tl.constexpr,
184
+ V: tl.constexpr,
185
+ BT: tl.constexpr,
186
+ BS: tl.constexpr,
187
+ BK: tl.constexpr,
188
+ BV: tl.constexpr,
189
+ USE_OFFSETS: tl.constexpr
190
+ ):
191
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
192
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
193
+ i_h = i_hq // G
194
+
195
+ if USE_OFFSETS:
196
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
197
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
198
+ T = eos - bos
199
+ else:
200
+ i_n = i_b
201
+ bos, eos = i_n * T, i_n * T + T
202
+
203
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
204
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
205
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
206
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
207
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
208
+
209
+ # [BT, BK]
210
+ b_q = tl.load(p_q, boundary_check=(0, 1))
211
+ b_q = (b_q * scale).to(b_q.dtype)
212
+ # [BT, BV]
213
+ b_do = tl.load(p_do, boundary_check=(0, 1))
214
+ # [BT]
215
+ b_lse = tl.load(p_lse, boundary_check=(0,))
216
+ b_delta = tl.load(p_delta, boundary_check=(0,))
217
+
218
+ # [BT, BK]
219
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
220
+ for i_s in range(0, i_t * BT, BS):
221
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
222
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
223
+ # [BK, BS]
224
+ b_k = tl.load(p_k, boundary_check=(0, 1))
225
+ # [BV, BS]
226
+ b_v = tl.load(p_v, boundary_check=(0, 1))
227
+
228
+ # [BT, BS]
229
+ b_s = tl.dot(b_q, b_k)
230
+ # b_p = exp(b_s - b_lse[:, None])
231
+ # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[:, None]))
232
+ b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[:, None]))
233
+
234
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
235
+ b_dp = tl.dot(b_do, b_v)
236
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
237
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
238
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
239
+
240
+ # [BT]
241
+ o_q = i_t * BT + tl.arange(0, BT)
242
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
243
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
244
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
245
+ # [BS]
246
+ o_k = i_s + tl.arange(0, BS)
247
+ # [BK, BS]
248
+ b_k = tl.load(p_k, boundary_check=(0, 1))
249
+ # [BV, BS]
250
+ b_v = tl.load(p_v, boundary_check=(0, 1))
251
+
252
+ # [BT, BS]
253
+ b_s = tl.dot(b_q, b_k)
254
+ # b_p = exp(b_s - b_lse[:, None])
255
+ # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[:, None]))
256
+ b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[:, None]))
257
+ b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)
258
+
259
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
260
+ b_dp = tl.dot(b_do, b_v)
261
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
262
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
263
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
264
+
265
+ b_dq *= scale
266
+
267
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
268
+
269
+
270
+ @triton.heuristics({
271
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
272
+ })
273
+ @triton.autotune(
274
+ configs=[
275
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
276
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
277
+ for num_stages in [2, 3, 4, 5]
278
+ ],
279
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
280
+ )
281
+ @triton.jit(do_not_specialize=['T'])
282
+ def parallel_rect_attn_bwd_kernel_dkv(
283
+ q,
284
+ k,
285
+ v,
286
+ lse,
287
+ delta,
288
+ do,
289
+ dk,
290
+ dv,
291
+ offsets,
292
+ indices,
293
+ scale,
294
+ T,
295
+ B: tl.constexpr,
296
+ H: tl.constexpr,
297
+ HQ: tl.constexpr,
298
+ G: tl.constexpr,
299
+ K: tl.constexpr,
300
+ V: tl.constexpr,
301
+ BT: tl.constexpr,
302
+ BS: tl.constexpr,
303
+ BK: tl.constexpr,
304
+ BV: tl.constexpr,
305
+ USE_OFFSETS: tl.constexpr
306
+ ):
307
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
308
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
309
+ i_h = i_hq // G
310
+
311
+ if USE_OFFSETS:
312
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
313
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
314
+ T = eos - bos
315
+ else:
316
+ i_n = i_b
317
+ bos, eos = i_n * T, i_n * T + T
318
+
319
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
320
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
321
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
322
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
323
+
324
+ # [BT, BK]
325
+ b_k = tl.load(p_k, boundary_check=(0, 1))
326
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
327
+ # [BT, BV]
328
+ b_v = tl.load(p_v, boundary_check=(0, 1))
329
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
330
+
331
+ o_k = i_t * BT + tl.arange(0, BT)
332
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
333
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
334
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
335
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
336
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
337
+
338
+ # [BS]
339
+ o_q = i_s + tl.arange(0, BS)
340
+ # [BS, BK]
341
+ b_q = tl.load(p_q, boundary_check=(0, 1))
342
+ b_q = (b_q * scale).to(b_q.dtype)
343
+ # [BS, BV]
344
+ b_do = tl.load(p_do, boundary_check=(0, 1))
345
+ # [BS]
346
+ b_lse = tl.load(p_lse, boundary_check=(0,))
347
+ b_delta = tl.load(p_delta, boundary_check=(0,))
348
+ # [BT, BS]
349
+ b_s = tl.dot(b_k, tl.trans(b_q))
350
+ # b_p = exp(b_s - b_lse[None, :])
351
+ # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[None, :]))
352
+ b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[None, :]))
353
+ b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
354
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
355
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
356
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
357
+ b_dp = tl.dot(b_v, tl.trans(b_do))
358
+ # [BT, BS]
359
+ b_ds = b_p * (b_dp - b_delta[None, :])
360
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
361
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
362
+
363
+ for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
364
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
365
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
366
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
367
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
368
+
369
+ # [BS]
370
+ o_q = i_s + tl.arange(0, BS)
371
+ # [BS, BK]
372
+ b_q = tl.load(p_q, boundary_check=(0, 1))
373
+ b_q = (b_q * scale).to(b_q.dtype)
374
+ # [BS, BV]
375
+ b_do = tl.load(p_do, boundary_check=(0, 1))
376
+ # [BS]
377
+ b_lse = tl.load(p_lse, boundary_check=(0,))
378
+ b_delta = tl.load(p_delta, boundary_check=(0,))
379
+ # [BT, BS]
380
+ b_s = tl.dot(b_k, tl.trans(b_q))
381
+ # b_p = exp(b_s - b_lse[None, :])
382
+ # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[None, :]))
383
+ b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[None, :]))
384
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
385
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
386
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
387
+ b_dp = tl.dot(b_v, tl.trans(b_do))
388
+ # [BT, BS]
389
+ b_ds = b_p * (b_dp - b_delta[None, :])
390
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
391
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
392
+
393
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
394
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
395
+
396
+
397
+ def parallel_rect_attn_fwd(
398
+ q: torch.Tensor,
399
+ k: torch.Tensor,
400
+ v: torch.Tensor,
401
+ scale: float,
402
+ chunk_size: int = 128,
403
+ offsets: Optional[torch.LongTensor] = None,
404
+ indices: Optional[torch.LongTensor] = None,
405
+ ):
406
+ B, T, H, K, V = *k.shape, v.shape[-1]
407
+ HQ = q.shape[2]
408
+ G = HQ // H
409
+ BT = chunk_size
410
+ if check_shared_mem('hopper', q.device.index):
411
+ BS = min(64, max(16, triton.next_power_of_2(T)))
412
+ BK = min(256, max(16, triton.next_power_of_2(K)))
413
+ BV = min(256, max(16, triton.next_power_of_2(V)))
414
+ elif check_shared_mem('ampere', q.device.index):
415
+ BS = min(32, max(16, triton.next_power_of_2(T)))
416
+ BK = min(256, max(16, triton.next_power_of_2(K)))
417
+ BV = min(128, max(16, triton.next_power_of_2(V)))
418
+ else:
419
+ BS = min(32, max(16, triton.next_power_of_2(T)))
420
+ BK = min(256, max(16, triton.next_power_of_2(K)))
421
+ BV = min(64, max(16, triton.next_power_of_2(V)))
422
+ NK = triton.cdiv(K, BK)
423
+ NV = triton.cdiv(V, BV)
424
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
425
+ assert NK == 1, "The key dimension cannot be larger than 256"
426
+
427
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
428
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
429
+
430
+ grid = (NV, NT, B * HQ)
431
+ parallel_rect_attn_fwd_kernel[grid](
432
+ q=q,
433
+ k=k,
434
+ v=v,
435
+ o=o,
436
+ lse=lse,
437
+ scale=scale,
438
+ offsets=offsets,
439
+ indices=indices,
440
+ B=B,
441
+ T=T,
442
+ H=H,
443
+ HQ=HQ,
444
+ G=G,
445
+ K=K,
446
+ V=V,
447
+ BT=BT,
448
+ BS=BS,
449
+ BK=BK,
450
+ BV=BV,
451
+ )
452
+ return o, lse
453
+
454
+
455
+ def parallel_rect_attn_bwd_preprocess(
456
+ o: torch.Tensor,
457
+ do: torch.Tensor
458
+ ):
459
+ V = o.shape[-1]
460
+ delta = torch.empty_like(o[..., 0], dtype=torch.float32)
461
+ parallel_rect_attn_bwd_kernel_preprocess[(delta.numel(),)](
462
+ o=o,
463
+ do=do,
464
+ delta=delta,
465
+ B=triton.next_power_of_2(V),
466
+ V=V,
467
+ )
468
+ return delta
469
+
470
+
471
+ def parallel_rect_attn_bwd(
472
+ q: torch.Tensor,
473
+ k: torch.Tensor,
474
+ v: torch.Tensor,
475
+ o: torch.Tensor,
476
+ lse: torch.Tensor,
477
+ do: torch.Tensor,
478
+ scale: float = None,
479
+ chunk_size: int = 128,
480
+ offsets: Optional[torch.LongTensor] = None,
481
+ indices: Optional[torch.LongTensor] = None,
482
+ ):
483
+ B, T, H, K, V = *k.shape, v.shape[-1]
484
+ HQ = q.shape[2]
485
+ G = HQ // H
486
+ BT = chunk_size
487
+ BS = max(16, triton.next_power_of_2(T))
488
+ BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
489
+ BK = max(16, triton.next_power_of_2(K))
490
+ BV = max(16, triton.next_power_of_2(V))
491
+ NV = triton.cdiv(V, BV)
492
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
493
+
494
+ delta = parallel_rect_attn_bwd_preprocess(o, do)
495
+
496
+ dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
497
+ dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
498
+ dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
499
+ grid = (NV, NT, B * HQ)
500
+ parallel_rect_attn_bwd_kernel_dq[grid](
501
+ q=q,
502
+ k=k,
503
+ v=v,
504
+ lse=lse,
505
+ delta=delta,
506
+ do=do,
507
+ dq=dq,
508
+ offsets=offsets,
509
+ indices=indices,
510
+ scale=scale,
511
+ T=T,
512
+ B=B,
513
+ H=H,
514
+ HQ=HQ,
515
+ G=G,
516
+ K=K,
517
+ V=V,
518
+ BT=BT,
519
+ BS=BS,
520
+ BK=BK,
521
+ BV=BV
522
+ )
523
+ parallel_rect_attn_bwd_kernel_dkv[grid](
524
+ q=q,
525
+ k=k,
526
+ v=v,
527
+ lse=lse,
528
+ delta=delta,
529
+ do=do,
530
+ dk=dk,
531
+ dv=dv,
532
+ offsets=offsets,
533
+ indices=indices,
534
+ scale=scale,
535
+ T=T,
536
+ B=B,
537
+ H=H,
538
+ HQ=HQ,
539
+ G=G,
540
+ K=K,
541
+ V=V,
542
+ BT=BT,
543
+ BS=BS,
544
+ BK=BK,
545
+ BV=BV
546
+ )
547
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
548
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
549
+ return dq, dk, dv
550
+
551
+
552
+ @torch.compile
553
+ class ParallelRectifiedAttentionFunction(torch.autograd.Function):
554
+
555
+ @staticmethod
556
+ @contiguous
557
+ @autocast_custom_fwd
558
+ def forward(ctx, q, k, v, scale, offsets):
559
+ ctx.dtype = q.dtype
560
+
561
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
562
+ # 2-d indices denoting the offsets of chunks in each sequence
563
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
564
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
565
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
566
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
567
+
568
+ o, lse = parallel_rect_attn_fwd(
569
+ q=q,
570
+ k=k,
571
+ v=v,
572
+ scale=scale,
573
+ chunk_size=chunk_size,
574
+ offsets=offsets,
575
+ indices=indices
576
+ )
577
+ ctx.save_for_backward(q, k, v, o, lse)
578
+ ctx.chunk_size = chunk_size
579
+ ctx.offsets = offsets
580
+ ctx.indices = indices
581
+ ctx.scale = scale
582
+ return o.to(q.dtype)
583
+
584
+ @staticmethod
585
+ @contiguous
586
+ @autocast_custom_bwd
587
+ def backward(ctx, do):
588
+ q, k, v, o, lse = ctx.saved_tensors
589
+ dq, dk, dv = parallel_rect_attn_bwd(
590
+ q=q,
591
+ k=k,
592
+ v=v,
593
+ o=o,
594
+ lse=lse,
595
+ do=do,
596
+ scale=ctx.scale,
597
+ chunk_size=ctx.chunk_size,
598
+ offsets=ctx.offsets,
599
+ indices=ctx.indices
600
+ )
601
+ return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
602
+
603
+
604
+ def parallel_rectified_attn(
605
+ q: torch.Tensor,
606
+ k: torch.Tensor,
607
+ v: torch.Tensor,
608
+ scale: Optional[float] = None,
609
+ cu_seqlens: Optional[torch.LongTensor] = None,
610
+ head_first: bool = False
611
+ ) -> torch.Tensor:
612
+ r"""
613
+ Args:
614
+ q (torch.Tensor):
615
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
616
+ k (torch.Tensor):
617
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
618
+ GQA will be applied if HQ is divisible by H.
619
+ v (torch.Tensor):
620
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
621
+ scale (Optional[float]):
622
+ Scale factor for attention scores.
623
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
624
+ cu_seqlens (torch.LongTensor):
625
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
626
+ consistent with the FlashAttention API.
627
+ head_first (Optional[bool]):
628
+ Whether the inputs are in the head-first format. Default: `False`.
629
+
630
+ Returns:
631
+ o (torch.Tensor):
632
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
633
+ """
634
+ if scale is None:
635
+ scale = k.shape[-1] ** -0.5
636
+ if cu_seqlens is not None:
637
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
638
+ if head_first:
639
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
640
+ o = ParallelRectifiedAttentionFunction.apply(q, k, v, scale, cu_seqlens)
641
+ if head_first:
642
+ o = rearrange(o, 'b t h d -> b h t d')
643
+ return o
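
A minimal usage sketch for `parallel_rectified_attn` (not part of the diff). It assumes a CUDA device, since the kernels are Triton-based, and that the import path follows the file location `fla/ops/attn/parallel_rectified.py`; shapes follow the docstring above and are hypothetical.

import torch
from fla.ops.attn.parallel_rectified import parallel_rectified_attn

# hypothetical shapes: 8 query heads share 4 key/value heads (GQA, G = 2)
B, T, H, HQ, K, V = 2, 1024, 4, 8, 64, 64
q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16, requires_grad=True)

o = parallel_rectified_attn(q, k, v)   # [B, T, HQ, V]
o.sum().backward()                     # exercises the dq/dkv backward kernels above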
fla/ops/attn/parallel_softpick.py ADDED
@@ -0,0 +1,650 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
23
+ for num_stages in [2, 3, 4, 5]
24
+ ],
25
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
26
+ )
27
+ @triton.jit
28
+ def parallel_softpick_attn_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ o,
33
+ lse,
34
+ scale,
35
+ offsets,
36
+ indices,
37
+ T,
38
+ B: tl.constexpr,
39
+ H: tl.constexpr,
40
+ HQ: tl.constexpr,
41
+ G: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BS: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ BV: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
52
+ i_h = i_hq // G
53
+
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ i_n = i_b
60
+ bos, eos = i_n * T, i_n * T + T
61
+
62
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
63
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
64
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
65
+
66
+ # the Q block is kept in the shared memory throughout the whole kernel
67
+ # [BT, BK]
68
+ b_q = tl.load(p_q, boundary_check=(0, 1))
69
+ b_q = (b_q * scale).to(b_q.dtype)
70
+ # [BT, BV]
71
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
72
+
73
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
74
+ b_acc = tl.zeros([BT], dtype=tl.float32)
75
+ for i_s in range(0, i_t * BT, BS):
76
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
77
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
78
+ # [BK, BS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BT, BS]
83
+ b_s = tl.dot(b_q, b_k)
84
+
85
+ # [BT]
86
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
87
+ b_r = exp(b_mp - b_m)
88
+ # [BT, BS]
89
+ b_p = exp(b_s - b_m[:, None]) - exp(-b_m[:, None])
90
+ b_p_r = tl.maximum(b_p, 0.0)
91
+ b_p_a = tl.abs(b_p)
92
+ # [BT]
93
+ b_acc = b_acc * b_r + tl.sum(b_p_a, 1)
94
+ # [BT, BV]
95
+ b_o = b_o * b_r[:, None] + tl.dot(b_p_r.to(b_q.dtype), b_v)
96
+
97
+ b_mp = b_m
98
+
99
+ # [BT]
100
+ o_q = i_t * BT + tl.arange(0, BT)
101
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
102
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
103
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
104
+
105
+ # [BS]
106
+ o_k = i_s + tl.arange(0, BS)
107
+ # [BK, BS]
108
+ b_k = tl.load(p_k, boundary_check=(0, 1))
109
+ # [BS, BV]
110
+ b_v = tl.load(p_v, boundary_check=(0, 1))
111
+ # [BT, BS]
112
+ b_s = tl.dot(b_q, b_k)
113
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
114
+
115
+ # [BT]
116
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
117
+ b_r = exp(b_mp - b_m)
118
+ # [BT, BS]
119
+ b_p = exp(b_s - b_m[:, None]) - exp(-b_m[:, None])
120
+ b_p_r = tl.maximum(b_p, 0.0)
121
+ b_p_a = tl.abs(b_p)
122
+ b_p_a = tl.where(o_q[:, None] >= o_k[None, :], b_p_a, 0)
123
+ # [BT]
124
+ b_acc = b_acc * b_r + tl.sum(b_p_a, 1)
125
+ # [BT, BV]
126
+ b_o = b_o * b_r[:, None] + tl.dot(b_p_r.to(b_q.dtype), b_v)
127
+
128
+ b_mp = b_m
129
+ b_acc += 1e-6 # hardcoded epsilon to keep the normalizer strictly positive
130
+ b_o = b_o / b_acc[:, None]
131
+ b_m += log(b_acc)
132
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
133
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
134
+
135
+
136
+ @triton.jit
137
+ def parallel_softpick_attn_bwd_kernel_preprocess(
138
+ o,
139
+ do,
140
+ delta,
141
+ B: tl.constexpr,
142
+ V: tl.constexpr
143
+ ):
144
+ i_n = tl.program_id(0)
145
+ o_d = tl.arange(0, B)
146
+ m_d = o_d < V
147
+
148
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
149
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
150
+ b_delta = tl.sum(b_o * b_do)
151
+
152
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
153
+
154
+
155
+ @triton.heuristics({
156
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
157
+ })
158
+ @triton.autotune(
159
+ configs=[
160
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
161
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
162
+ for num_stages in [2, 3, 4, 5]
163
+ ],
164
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
165
+ )
166
+ @triton.jit(do_not_specialize=['T'])
167
+ def parallel_softpick_attn_bwd_kernel_dq(
168
+ q,
169
+ k,
170
+ v,
171
+ lse,
172
+ delta,
173
+ do,
174
+ dq,
175
+ scale,
176
+ offsets,
177
+ indices,
178
+ T,
179
+ B: tl.constexpr,
180
+ H: tl.constexpr,
181
+ HQ: tl.constexpr,
182
+ G: tl.constexpr,
183
+ K: tl.constexpr,
184
+ V: tl.constexpr,
185
+ BT: tl.constexpr,
186
+ BS: tl.constexpr,
187
+ BK: tl.constexpr,
188
+ BV: tl.constexpr,
189
+ USE_OFFSETS: tl.constexpr
190
+ ):
191
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
192
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
193
+ i_h = i_hq // G
194
+
195
+ if USE_OFFSETS:
196
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
197
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
198
+ T = eos - bos
199
+ else:
200
+ i_n = i_b
201
+ bos, eos = i_n * T, i_n * T + T
202
+
203
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
204
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
205
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
206
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
207
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
208
+
209
+ # [BT, BK]
210
+ b_q = tl.load(p_q, boundary_check=(0, 1))
211
+ b_q = (b_q * scale).to(b_q.dtype)
212
+ # [BT, BV]
213
+ b_do = tl.load(p_do, boundary_check=(0, 1))
214
+ # [BT]
215
+ b_lse = tl.load(p_lse, boundary_check=(0,))
216
+ b_delta = tl.load(p_delta, boundary_check=(0,))
217
+
218
+ # [BT, BK]
219
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
220
+ for i_s in range(0, i_t * BT, BS):
221
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
222
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
223
+ # [BK, BS]
224
+ b_k = tl.load(p_k, boundary_check=(0, 1))
225
+ # [BV, BS]
226
+ b_v = tl.load(p_v, boundary_check=(0, 1))
227
+
228
+ # [BT, BS]
229
+ b_s = tl.dot(b_q, b_k)
230
+ b_e = exp(b_s - b_lse[:, None])
231
+
232
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
233
+ b_dp = tl.dot(b_do, b_v)
234
+ # [BT, BS]
235
+ b_step = tl.where(b_s > 0, b_dp, 0)
236
+ b_sign = tl.where(b_s > 0, b_delta[:, None], -b_delta[:, None])
237
+ b_ds = b_e * (b_step.to(tl.float32) - b_sign)
238
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
239
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
240
+
241
+ # [BT]
242
+ o_q = i_t * BT + tl.arange(0, BT)
243
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
244
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
245
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
246
+ # [BS]
247
+ o_k = i_s + tl.arange(0, BS)
248
+ # [BK, BS]
249
+ b_k = tl.load(p_k, boundary_check=(0, 1))
250
+ # [BV, BS]
251
+ b_v = tl.load(p_v, boundary_check=(0, 1))
252
+
253
+ # [BT, BS]
254
+ b_s = tl.dot(b_q, b_k)
255
+ b_e = exp(b_s - b_lse[:, None])
256
+
257
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
258
+ b_dp = tl.dot(b_do, b_v)
259
+ # [BT, BS]
260
+ b_e = tl.where(o_q[:, None] >= o_k[None, :], b_e, 0)
261
+ b_step = tl.where(b_s > 0, b_dp, 0)
262
+ b_sign = tl.where(b_s > 0, b_delta[:, None], -b_delta[:, None])
263
+ b_ds = b_e * (b_step.to(tl.float32) - b_sign)
264
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
265
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
266
+
267
+ b_dq *= scale
268
+
269
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
270
+
271
+
272
+ @triton.heuristics({
273
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
274
+ })
275
+ @triton.autotune(
276
+ configs=[
277
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
278
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
279
+ for num_stages in [2, 3, 4, 5]
280
+ ],
281
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
282
+ )
283
+ @triton.jit(do_not_specialize=['T'])
284
+ def parallel_softpick_attn_bwd_kernel_dkv(
285
+ q,
286
+ k,
287
+ v,
288
+ lse,
289
+ delta,
290
+ do,
291
+ dk,
292
+ dv,
293
+ offsets,
294
+ indices,
295
+ scale,
296
+ T,
297
+ B: tl.constexpr,
298
+ H: tl.constexpr,
299
+ HQ: tl.constexpr,
300
+ G: tl.constexpr,
301
+ K: tl.constexpr,
302
+ V: tl.constexpr,
303
+ BT: tl.constexpr,
304
+ BS: tl.constexpr,
305
+ BK: tl.constexpr,
306
+ BV: tl.constexpr,
307
+ USE_OFFSETS: tl.constexpr
308
+ ):
309
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
310
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
311
+ i_h = i_hq // G
312
+
313
+ if USE_OFFSETS:
314
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
315
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
316
+ T = eos - bos
317
+ else:
318
+ i_n = i_b
319
+ bos, eos = i_n * T, i_n * T + T
320
+
321
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
322
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
323
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
324
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
325
+
326
+ # [BT, BK]
327
+ b_k = tl.load(p_k, boundary_check=(0, 1))
328
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
329
+ # [BT, BV]
330
+ b_v = tl.load(p_v, boundary_check=(0, 1))
331
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
332
+
333
+ o_k = i_t * BT + tl.arange(0, BT)
334
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
335
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
336
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
337
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
338
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
339
+
340
+ # [BS]
341
+ o_q = i_s + tl.arange(0, BS)
342
+ # [BS, BK]
343
+ b_q = tl.load(p_q, boundary_check=(0, 1))
344
+ b_q = (b_q * scale).to(b_q.dtype)
345
+ # [BS, BV]
346
+ b_do = tl.load(p_do, boundary_check=(0, 1))
347
+ # [BS]
348
+ b_lse = tl.load(p_lse, boundary_check=(0,))
349
+ b_delta = tl.load(p_delta, boundary_check=(0,))
350
+ # [BT, BS]
351
+ b_s = tl.dot(b_k, tl.trans(b_q))
352
+ b_e = exp(b_s - b_lse[None, :])
353
+ b_p = b_e - exp(-b_lse[None, :])
354
+ b_p_r = tl.maximum(b_p, 0.0)
355
+ b_p_r = tl.where(o_k[:, None] <= o_q[None, :], b_p_r, 0)
356
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
357
+ b_dv += tl.dot(b_p_r.to(b_do.dtype), b_do)
358
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
359
+ b_dp = tl.dot(b_v, tl.trans(b_do))
360
+ # [BT, BS]
361
+ b_e = tl.where(o_k[:, None] <= o_q[None, :], b_e, 0)
362
+ b_step = tl.where(b_s > 0, b_dp, 0)
363
+ b_sign = tl.where(b_s > 0, b_delta[None, :], -b_delta[None, :])
364
+ b_ds = b_e * (b_step - b_sign)
365
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
366
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
367
+
368
+ for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
369
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
370
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
371
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
372
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
373
+
374
+ # [BS]
375
+ o_q = i_s + tl.arange(0, BS)
376
+ # [BS, BK]
377
+ b_q = tl.load(p_q, boundary_check=(0, 1))
378
+ b_q = (b_q * scale).to(b_q.dtype)
379
+ # [BS, BV]
380
+ b_do = tl.load(p_do, boundary_check=(0, 1))
381
+ # [BS]
382
+ b_lse = tl.load(p_lse, boundary_check=(0,))
383
+ b_delta = tl.load(p_delta, boundary_check=(0,))
384
+ # [BT, BS]
385
+ b_s = tl.dot(b_k, tl.trans(b_q))
386
+ b_e = exp(b_s - b_lse[None, :])
387
+ b_p = b_e - exp(-b_lse[None, :])
388
+ b_p_r = tl.maximum(b_p, 0.0)
389
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
390
+ b_dv += tl.dot(b_p_r.to(b_do.dtype), b_do)
391
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
392
+ b_dp = tl.dot(b_v, tl.trans(b_do))
393
+ # [BT, BS]
394
+ b_step = tl.where(b_s > 0, b_dp, 0)
395
+ b_sign = tl.where(b_s > 0, b_delta[None, :], -b_delta[None, :])
396
+ b_ds = b_e * (b_step - b_sign)
397
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
398
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
399
+
400
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
401
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
402
+
403
+
404
+ def parallel_softpick_attn_fwd(
405
+ q: torch.Tensor,
406
+ k: torch.Tensor,
407
+ v: torch.Tensor,
408
+ scale: float,
409
+ chunk_size: int = 128,
410
+ offsets: Optional[torch.LongTensor] = None,
411
+ indices: Optional[torch.LongTensor] = None,
412
+ ):
413
+ B, T, H, K, V = *k.shape, v.shape[-1]
414
+ HQ = q.shape[2]
415
+ G = HQ // H
416
+ BT = chunk_size
417
+ if check_shared_mem('hopper', q.device.index):
418
+ BS = min(64, max(16, triton.next_power_of_2(T)))
419
+ BK = min(256, max(16, triton.next_power_of_2(K)))
420
+ BV = min(256, max(16, triton.next_power_of_2(V)))
421
+ elif check_shared_mem('ampere', q.device.index):
422
+ BS = min(32, max(16, triton.next_power_of_2(T)))
423
+ BK = min(256, max(16, triton.next_power_of_2(K)))
424
+ BV = min(128, max(16, triton.next_power_of_2(V)))
425
+ else:
426
+ BS = min(32, max(16, triton.next_power_of_2(T)))
427
+ BK = min(256, max(16, triton.next_power_of_2(K)))
428
+ BV = min(64, max(16, triton.next_power_of_2(V)))
429
+ NK = triton.cdiv(K, BK)
430
+ NV = triton.cdiv(V, BV)
431
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
432
+ assert NK == 1, "The key dimension cannot be larger than 256"
433
+
434
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
435
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
436
+
437
+ grid = (NV, NT, B * HQ)
438
+ parallel_softpick_attn_fwd_kernel[grid](
439
+ q=q,
440
+ k=k,
441
+ v=v,
442
+ o=o,
443
+ lse=lse,
444
+ scale=scale,
445
+ offsets=offsets,
446
+ indices=indices,
447
+ B=B,
448
+ T=T,
449
+ H=H,
450
+ HQ=HQ,
451
+ G=G,
452
+ K=K,
453
+ V=V,
454
+ BT=BT,
455
+ BS=BS,
456
+ BK=BK,
457
+ BV=BV,
458
+ )
459
+ return o, lse
460
+
461
+
462
+ def parallel_softpick_attn_bwd_preprocess(
463
+ o: torch.Tensor,
464
+ do: torch.Tensor
465
+ ):
466
+ V = o.shape[-1]
467
+ delta = torch.empty_like(o[..., 0], dtype=torch.float32)
468
+ parallel_softpick_attn_bwd_kernel_preprocess[(delta.numel(),)](
469
+ o=o,
470
+ do=do,
471
+ delta=delta,
472
+ B=triton.next_power_of_2(V),
473
+ V=V,
474
+ )
475
+ return delta
476
+
477
+
478
+ def parallel_softpick_attn_bwd(
479
+ q: torch.Tensor,
480
+ k: torch.Tensor,
481
+ v: torch.Tensor,
482
+ o: torch.Tensor,
483
+ lse: torch.Tensor,
484
+ do: torch.Tensor,
485
+ scale: float = None,
486
+ chunk_size: int = 128,
487
+ offsets: Optional[torch.LongTensor] = None,
488
+ indices: Optional[torch.LongTensor] = None,
489
+ ):
490
+ B, T, H, K, V = *k.shape, v.shape[-1]
491
+ HQ = q.shape[2]
492
+ G = HQ // H
493
+ BT = chunk_size
494
+ BS = max(16, triton.next_power_of_2(T))
495
+ BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
496
+ BK = max(16, triton.next_power_of_2(K))
497
+ BV = max(16, triton.next_power_of_2(V))
498
+ NV = triton.cdiv(V, BV)
499
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
500
+
501
+ delta = parallel_softpick_attn_bwd_preprocess(o, do)
502
+
503
+ dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
504
+ dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
505
+ dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
506
+ grid = (NV, NT, B * HQ)
507
+ parallel_softpick_attn_bwd_kernel_dq[grid](
508
+ q=q,
509
+ k=k,
510
+ v=v,
511
+ lse=lse,
512
+ delta=delta,
513
+ do=do,
514
+ dq=dq,
515
+ offsets=offsets,
516
+ indices=indices,
517
+ scale=scale,
518
+ T=T,
519
+ B=B,
520
+ H=H,
521
+ HQ=HQ,
522
+ G=G,
523
+ K=K,
524
+ V=V,
525
+ BT=BT,
526
+ BS=BS,
527
+ BK=BK,
528
+ BV=BV
529
+ )
530
+ parallel_softpick_attn_bwd_kernel_dkv[grid](
531
+ q=q,
532
+ k=k,
533
+ v=v,
534
+ lse=lse,
535
+ delta=delta,
536
+ do=do,
537
+ dk=dk,
538
+ dv=dv,
539
+ offsets=offsets,
540
+ indices=indices,
541
+ scale=scale,
542
+ T=T,
543
+ B=B,
544
+ H=H,
545
+ HQ=HQ,
546
+ G=G,
547
+ K=K,
548
+ V=V,
549
+ BT=BT,
550
+ BS=BS,
551
+ BK=BK,
552
+ BV=BV
553
+ )
554
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
555
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
556
+ return dq, dk, dv
557
+
558
+
559
+ @torch.compile
560
+ class ParallelSoftpickAttentionFunction(torch.autograd.Function):
561
+
562
+ @staticmethod
563
+ @contiguous
564
+ @autocast_custom_fwd
565
+ def forward(ctx, q, k, v, scale, offsets):
566
+ ctx.dtype = q.dtype
567
+
568
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
569
+ # 2-d indices denoting the offsets of chunks in each sequence
570
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
571
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
572
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
573
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
574
+
575
+ o, lse = parallel_softpick_attn_fwd(
576
+ q=q,
577
+ k=k,
578
+ v=v,
579
+ scale=scale,
580
+ chunk_size=chunk_size,
581
+ offsets=offsets,
582
+ indices=indices
583
+ )
584
+ ctx.save_for_backward(q, k, v, o, lse)
585
+ ctx.chunk_size = chunk_size
586
+ ctx.offsets = offsets
587
+ ctx.indices = indices
588
+ ctx.scale = scale
589
+ return o.to(q.dtype)
590
+
591
+ @staticmethod
592
+ @contiguous
593
+ @autocast_custom_bwd
594
+ def backward(ctx, do):
595
+ q, k, v, o, lse = ctx.saved_tensors
596
+ dq, dk, dv = parallel_softpick_attn_bwd(
597
+ q=q,
598
+ k=k,
599
+ v=v,
600
+ o=o,
601
+ lse=lse,
602
+ do=do,
603
+ scale=ctx.scale,
604
+ chunk_size=ctx.chunk_size,
605
+ offsets=ctx.offsets,
606
+ indices=ctx.indices
607
+ )
608
+ return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
609
+
610
+
611
+ def parallel_softpick_attn(
612
+ q: torch.Tensor,
613
+ k: torch.Tensor,
614
+ v: torch.Tensor,
615
+ scale: Optional[float] = None,
616
+ cu_seqlens: Optional[torch.LongTensor] = None,
617
+ head_first: bool = False
618
+ ) -> torch.Tensor:
619
+ r"""
620
+ Args:
621
+ q (torch.Tensor):
622
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
623
+ k (torch.Tensor):
624
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
625
+ GQA will be applied if HQ is divisible by H.
626
+ v (torch.Tensor):
627
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
628
+ scale (Optional[float]):
629
+ Scale factor for attention scores.
630
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
631
+ cu_seqlens (torch.LongTensor):
632
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
633
+ consistent with the FlashAttention API.
634
+ head_first (Optional[bool]):
635
+ Whether the inputs are in the head-first format. Default: `False`.
636
+
637
+ Returns:
638
+ o (torch.Tensor):
639
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
640
+ """
641
+ if scale is None:
642
+ scale = k.shape[-1] ** -0.5
643
+ if cu_seqlens is not None:
644
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
645
+ if head_first:
646
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
647
+ o = ParallelSoftpickAttentionFunction.apply(q, k, v, scale, cu_seqlens)
648
+ if head_first:
649
+ o = rearrange(o, 'b t h d -> b h t d')
650
+ return o
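
As a readability aid (not part of the diff), here is a hedged, dense PyTorch sketch of the per-row weighting that the blockwise forward kernel in this file accumulates, as I read it: scores are shifted by the running row maximum, `exp(s) - 1` is rectified for the numerator and taken in absolute value for the normalizer, and a small epsilon keeps the normalizer positive.

import torch

def softpick_weights(s: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # s: [T, T] pre-softmax scores for a single head, already scaled by 1/sqrt(K)
    T = s.shape[-1]
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=s.device))
    s = s.masked_fill(~causal, float('-inf'))
    m = s.max(dim=-1, keepdim=True).values               # running maximum (b_m in the kernel)
    p = torch.exp(s - m) - torch.exp(-m)                  # (exp(s) - 1), rescaled by exp(-m)
    num = p.clamp_min(0.0)                                # rectified numerator (b_p_r)
    den = p.abs().masked_fill(~causal, 0.0).sum(-1, keepdim=True) + eps  # normalizer (b_acc)
    return num / den                                      # row weights applied to v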
fla/ops/based/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (289 Bytes).
 
fla/ops/based/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (22.6 kB).
 
fla/ops/based/naive.py ADDED
@@ -0,0 +1,72 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import rearrange
7
+
8
+
9
+ def naive_parallel_based(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ scale: Optional[float] = None,
14
+ use_norm: bool = True
15
+ ):
16
+ if scale is None:
17
+ scale = q.shape[-1] ** -0.5
18
+ q = q * scale
19
+ attn = q @ k.transpose(-2, -1)
20
+ attn = 1 + attn + 1/2 * (attn ** 2)
21
+ attn.masked_fill_(~torch.tril(torch.ones(
22
+ q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
23
+ o = attn @ v
24
+ if use_norm:
25
+ z = attn.sum(-1)
26
+ return o / (z[..., None] + 1e-6)
27
+ else:
28
+ return o
29
+
30
+
31
+ def naive_chunk_based(q, k, v, chunk_size=256):
32
+ q = q * (q.shape[-1] ** -0.5)
33
+ # compute normalizer.
34
+ k_cumsum = torch.cumsum(k, dim=-2)
35
+ kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
36
+ # first-order term
37
+ z = (q * k_cumsum).sum(-1)
38
+ # second order
39
+ z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
40
+ # zero-th order
41
+ z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
42
+
43
+ # compute o
44
+ # constant term
45
+ _o = v.cumsum(-2)
46
+
47
+ q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
48
+
49
+ k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
50
+ v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
51
+
52
+ intra_chunk_attn = q @ k.transpose(-2, -1)
53
+ intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
54
+ intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
55
+ o = intra_chunk_attn @ v
56
+
57
+ # quadratic term
58
+ kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
59
+ kv = kv.cumsum(2)
60
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
61
+
62
+ o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
63
+
64
+ # linear term
65
+ kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
66
+ kv = kv.cumsum(2)
67
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
68
+ o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
69
+
70
+ o = rearrange(o, 'b h n c d -> b h (n c) d')
71
+ o = o + _o
72
+ return o / (z[..., None] + 1e-6)
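
A hypothetical sanity check (not part of the diff) relating the two functions above: on head-first `[B, H, T, D]` inputs whose length is a multiple of the chunk size, the chunkwise form should reproduce the quadratic reference up to floating-point accumulation error.

import torch
from fla.ops.based.naive import naive_parallel_based, naive_chunk_based

B, H, T, D = 2, 4, 512, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
ref = naive_parallel_based(q, k, v)               # O(T^2) Taylor-feature attention
out = naive_chunk_based(q, k, v, chunk_size=256)  # same map, computed chunkwise
print((ref - out).abs().max())                    # expected to be small (numerical error only)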
fla/ops/based/parallel.py ADDED
@@ -0,0 +1,410 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
11
+
12
+ # Based: An Educational and Effective Sequence Mixer
13
+ # https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
14
+
15
+
16
+ @triton.jit(do_not_specialize=['T'])
17
+ def parallel_based_fwd_kernel(
18
+ q,
19
+ k,
20
+ v,
21
+ o,
22
+ z,
23
+ scale,
24
+ T,
25
+ B: tl.constexpr,
26
+ H: tl.constexpr,
27
+ K: tl.constexpr,
28
+ V: tl.constexpr,
29
+ BTL: tl.constexpr,
30
+ BTS: tl.constexpr,
31
+ BK: tl.constexpr,
32
+ BV: tl.constexpr,
33
+ ):
34
+ # i_c: chunk index. used for sequence parallelism
35
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
36
+ NV = tl.cdiv(V, BV)
37
+ i_k = i_kv // (NV)
38
+ i_v = i_kv % (NV)
39
+
40
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
41
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BTS), (0, 1))
42
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BTS, BV), (1, 0))
43
+
44
+ # [BQ, BD] block Q, in the shared memory throughout the whole kernel
45
+ b_q = tl.load(p_q, boundary_check=(0, 1))
46
+ b_q = (b_q * scale).to(b_q.dtype)
47
+ b_o = tl.zeros([BTL, BV], dtype=tl.float32)
48
+ b_z = tl.zeros([BTL], dtype=tl.float32)
49
+
50
+ # Q block and K block have no overlap
51
+ # no need for mask, thereby saving flops
52
+ for _ in range(0, i_c * BTL, BTS):
53
+ # [BK, BTS]
54
+ b_k = tl.load(p_k, boundary_check=(0, 1))
55
+
56
+ # [BTS, BV]
57
+ b_v = tl.load(p_v, boundary_check=(0, 1))
58
+ # [BTL, BTS]
59
+ b_s = tl.dot(b_q, (b_k), allow_tf32=False)
60
+ b_s = 1 + b_s + 0.5 * b_s * b_s
61
+ b_z += tl.sum(b_s, axis=1)
62
+
63
+ # [BQ, BD]
64
+ b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
65
+ p_k = tl.advance(p_k, (0, BTS))
66
+ p_v = tl.advance(p_v, (BTS, 0))
67
+
68
+ # # rescale interchunk output
69
+ tl.debug_barrier()
70
+ o_q = tl.arange(0, BTL)
71
+ # # sync threads, easy for compiler to optimize
72
+ # tl.debug_barrier()
73
+
74
+ o_k = tl.arange(0, BTS)
75
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))
76
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))
77
+ # Q block and K block have overlap. masks required
78
+ for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
79
+ # [BK, BTS]
80
+ b_k = tl.load(p_k, boundary_check=(0, 1))
81
+ # [BTS, BV]
82
+ b_v = tl.load(p_v, boundary_check=(0, 1))
83
+ # [BTL, BTS]
84
+ m_s = o_q[:, None] >= o_k[None, :]
85
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
86
+ b_s = 1 + b_s + 0.5 * b_s * b_s
87
+ b_s = tl.where(m_s, b_s, 0)
88
+ b_z += tl.sum(b_s, axis=1)
89
+ # [BTL, BV]
90
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
91
+
92
+ p_k = tl.advance(p_k, (0, BTS))
93
+ p_v = tl.advance(p_v, (BTS, 0))
94
+ o_k += BTS
95
+
96
+ p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
97
+ p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)
98
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
99
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T))
100
+
101
+
102
+ @triton.jit
103
+ def _parallel_based_bwd_dq(
104
+ i_bh,
105
+ i_c,
106
+ i_k,
107
+ i_v,
108
+ q,
109
+ k,
110
+ v,
111
+ do,
112
+ dz,
113
+ dq,
114
+ scale,
115
+ T,
116
+ B: tl.constexpr,
117
+ H: tl.constexpr,
118
+ BTL: tl.constexpr,
119
+ BTS: tl.constexpr,
120
+ BK: tl.constexpr,
121
+ BV: tl.constexpr,
122
+ K: tl.constexpr,
123
+ V: tl.constexpr,
124
+ ):
125
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
126
+ p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
127
+ b_q = tl.load(p_q, boundary_check=(0, 1))
128
+ b_q = (b_q * scale).to(b_q.dtype)
129
+
130
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
131
+ b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
132
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BTS, BK), (1, 0))
133
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, 0), (BV, BTS), (0, 1))
134
+ p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)
135
+ b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)
136
+
137
+ for _ in range(0, i_c * BTL, BTS):
138
+ # [BTS, BK]
139
+ b_k = tl.load(p_k, boundary_check=(0, 1))
140
+ # [BV, BTS]
141
+ b_v = tl.load(p_v, boundary_check=(0, 1))
142
+ # [BTL, BTS]
143
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
144
+ if i_v == 0:
145
+ b_ds += b_dz[:, None]
146
+ else:
147
+ b_ds = b_ds
148
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
149
+ # [BQ, BD]
150
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False)
151
+ p_k = tl.advance(p_k, (BTS, 0))
152
+ p_v = tl.advance(p_v, (0, BTS))
153
+
154
+ b_dq *= scale
155
+ o_q = tl.arange(0, BTL)
156
+ o_k = tl.arange(0, BTS)
157
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))
158
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))
159
+ # Q block and K block have overlap. masks required
160
+ for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):
161
+ # [BTS, BK]
162
+ b_k = tl.load(p_k, boundary_check=(0, 1))
163
+ # [BV, BTS]
164
+ b_v = tl.load(p_v, boundary_check=(0, 1))
165
+ # [BTL, BTS]
166
+ m_s = o_q[:, None] >= o_k[None, :]
167
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
168
+ if i_v == 0:
169
+ b_ds += b_dz[:, None]
170
+ else:
171
+ b_ds = b_ds
172
+ b_ds = tl.where(m_s, b_ds, 0) * scale
173
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
174
+ b_s = tl.where(m_s, b_s, 0)
175
+ # [BTL, BK]
176
+ b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False)
177
+ p_k = tl.advance(p_k, (BTS, 0))
178
+ p_v = tl.advance(p_v, (0, BTS))
179
+ o_k += BTS
180
+ p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
181
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
182
+ return
183
+
184
+
185
+ @triton.jit
186
+ def _parallel_based_bwd_dkv(
187
+ i_bh,
188
+ i_c,
189
+ i_k,
190
+ i_v,
191
+ q,
192
+ k,
193
+ v,
194
+ do,
195
+ dz,
196
+ dk,
197
+ dv,
198
+ scale,
199
+ T,
200
+ B: tl.constexpr,
201
+ H: tl.constexpr,
202
+ BTL: tl.constexpr,
203
+ BTS: tl.constexpr,
204
+ BK: tl.constexpr,
205
+ BV: tl.constexpr,
206
+ K: tl.constexpr,
207
+ V: tl.constexpr,
208
+ ):
209
+ # compute dk dv
210
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))
211
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))
212
+ b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
213
+ b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32)
214
+
215
+ for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
216
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
217
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
218
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
219
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]
220
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]
221
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
222
+ b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale # [BTL, BTS]
223
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
224
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
225
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
226
+ if i_v == 0:
227
+ b_ds += b_dz[None, :] * scale
228
+ else:
229
+ b_ds = b_ds
230
+ b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
231
+
232
+ tl.debug_barrier()
233
+ o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
234
+ for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
235
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1))
236
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1))
237
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
238
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
239
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
240
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
241
+ # [BK, BQ]
242
+ m_s = o_k[:, None] <= o_q[None, :]
243
+ b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
244
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
245
+ b_s = tl.where(m_s, b_s, 0)
246
+ b_s2 = tl.where(m_s, b_s2, 0)
247
+
248
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False)
249
+ if i_v == 0:
250
+ b_ds += b_dz[None, :]
251
+ else:
252
+ b_ds = b_ds
253
+ b_ds = tl.where(m_s, b_ds, 0) * scale
254
+ # [BK, BD]
255
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
256
+ b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
257
+ o_q += BTS
258
+
259
+ p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
260
+ p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
261
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
262
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
263
+ return
264
+
265
+
266
+ @triton.jit(do_not_specialize=['T'])
267
+ def parallel_based_bwd_kernel(
268
+ q,
269
+ k,
270
+ v,
271
+ do,
272
+ dz,
273
+ dq,
274
+ dk,
275
+ dv,
276
+ scale,
277
+ T,
278
+ B: tl.constexpr,
279
+ H: tl.constexpr,
280
+ K: tl.constexpr,
281
+ V: tl.constexpr,
282
+ BTL: tl.constexpr,
283
+ BTS: tl.constexpr,
284
+ BK: tl.constexpr,
285
+ BV: tl.constexpr,
286
+ ):
287
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
288
+ NV = tl.cdiv(V, BV)
289
+ i_k = i_kv // (NV)
290
+ i_v = i_kv % NV
291
+ _parallel_based_bwd_dq(
292
+ i_bh, i_c, i_k, i_v,
293
+ q, k, v, do, dz, dq,
294
+ scale, T, B, H, BTL, BTS, BK, BV, K, V
295
+ )
296
+ tl.debug_barrier()
297
+ _parallel_based_bwd_dkv(
298
+ i_bh, i_c, i_k, i_v,
299
+ q, k, v, do, dz, dk, dv,
300
+ scale, T, B, H, BTL, BTS, BK, BV, K, V
301
+ )
302
+
303
+
304
+ class ParallelBasedFunction(torch.autograd.Function):
305
+
306
+ @staticmethod
307
+ @input_guard
308
+ @autocast_custom_fwd
309
+ def forward(ctx, q, k, v, scale):
310
+ BTL, BTS = 128, 32
311
+ assert BTL % BTS == 0
312
+ # assert q.shape[-1] % 16 == 0
313
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
314
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
315
+ BK, BV = max(BK, 16), max(BV, 16)
316
+ B, H, T, K, V = *k.shape, v.shape[-1]
317
+ num_stages = 2
318
+ num_warps = 4
319
+ NK = triton.cdiv(K, BK)
320
+ NV = triton.cdiv(V, BV)
321
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
322
+
323
+ assert NK == 1, "NK must be 1; splitting over K would require extra cross-block synchronization"
324
+
325
+ o = torch.empty(NK, B, H, T, V, device=q.device)
326
+ z = torch.empty(NK, B, H, T, device=q.device)
327
+ parallel_based_fwd_kernel[grid](
328
+ q, k, v, o, z,
329
+ scale,
330
+ B=B,
331
+ H=H,
332
+ T=T,
333
+ K=K,
334
+ V=V,
335
+ BTL=BTL,
336
+ BTS=BTS,
337
+ BK=BK,
338
+ BV=BV,
339
+ num_warps=num_warps,
340
+ num_stages=num_stages
341
+ )
342
+ ctx.save_for_backward(q, k, v)
343
+ ctx.scale = scale
344
+ return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
345
+
346
+ @staticmethod
347
+ @input_guard
348
+ @autocast_custom_bwd
349
+ def backward(ctx, do, dz):
350
+ q, k, v = ctx.saved_tensors
351
+ scale = ctx.scale
352
+ BTL, BTS = 64, 32
353
+ assert BTL % BTS == 0
354
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
355
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
356
+ BK, BV = max(BK, 16), max(BV, 16)
357
+ B, H, T, K, V = *k.shape, v.shape[-1]
358
+ num_stages = 2
359
+ num_warps = 4
360
+ NK = triton.cdiv(K, BK)
361
+ NV = triton.cdiv(V, BV)
362
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
363
+
364
+ assert NK == 1, "NK must be 1; splitting over K would require extra cross-block synchronization"
365
+
366
+ dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
367
+ dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
368
+ dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
369
+
370
+ parallel_based_bwd_kernel[grid](
371
+ q, k, v, do, dz, dq, dk, dv,
372
+ scale,
373
+ B=B,
374
+ H=H,
375
+ T=T,
376
+ K=K,
377
+ V=V,
378
+ BTL=BTL,
379
+ BTS=BTS,
380
+ BK=BK,
381
+ BV=BV,
382
+ num_warps=num_warps,
383
+ num_stages=num_stages
384
+ )
385
+
386
+ return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
387
+
388
+
389
+ triton_parallel_based = ParallelBasedFunction.apply
390
+
391
+
392
+ def parallel_based(
393
+ q: torch.Tensor,
394
+ k: torch.Tensor,
395
+ v: torch.Tensor,
396
+ scale: Optional[float] = None,
397
+ use_norm: bool = True,
398
+ head_first: bool = True
399
+ ):
400
+ assert q.shape[-1] <= 128, "only feature dims up to 128 are supported"
401
+ if scale is None:
402
+ scale = q.shape[-1] ** -0.5
403
+ if not head_first:
404
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
405
+ o, z = triton_parallel_based(q, k, v, scale)
406
+ if use_norm:
407
+ o = o / (z[..., None] + 1e-6)
408
+ if not head_first:
409
+ o = o.transpose(1, 2)
410
+ return o.to(q.dtype)
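
A hypothetical smoke test (not part of the diff), assuming a CUDA device: the fused Triton path should track the quadratic reference from `fla/ops/based/naive.py` above on head-first inputs.

import torch
from fla.ops.based.naive import naive_parallel_based
from fla.ops.based.parallel import parallel_based

B, H, T, D = 2, 4, 256, 16
q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))
ref = naive_parallel_based(q, k, v)
out = parallel_based(q, k, v, head_first=True)    # inputs are already [B, H, T, D]
print((ref - out).abs().max())                    # expected to be small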
fla/ops/common/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (142 Bytes).
 
fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc ADDED
Binary file (23.9 kB).
 
fla/ops/common/__pycache__/chunk_o.cpython-312.pyc ADDED
Binary file (37 kB).
 
fla/ops/common/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (32.4 kB).
 
fla/ops/common/__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.42 kB).
 
fla/ops/delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (364 Bytes).
 
fla/ops/delta_rule/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (20.5 kB).
 
fla/ops/delta_rule/naive.py ADDED
@@ -0,0 +1,120 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
+ def delta_rule_recurrence(q, k, v, beta, initial_state=None, output_final_state=True):
8
+ orig_dtype = q.dtype
9
+ b, h, l, d_k = q.shape
10
+ q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta])
11
+ d_v = v.shape[-1]
12
+ o = torch.zeros_like(v)
13
+ S = torch.zeros(b, h, d_k, d_v).to(v)
14
+ q = q * (d_k ** -0.5)
15
+
16
+ if beta.ndim < v.ndim:
17
+ beta = beta[..., None]
18
+
19
+ if initial_state is not None:
20
+ S += initial_state
21
+
22
+ for i in range(l):
23
+ _k = k[:, :, i]
24
+ _q = q[:, :, i]
25
+ _v = v[:, :, i].clone()
26
+ beta_i = beta[:, :, i]
27
+ _v = _v - (S.clone() * _k[..., None]).sum(-2)
28
+ _v = _v * beta_i
29
+ S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
30
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
31
+ S = None if output_final_state is False else S
32
+ return o.to(orig_dtype), S
33
+
34
+
35
+ def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
36
+ b, h, l, d_k = q.shape
37
+ d_v = v.shape[-1]
38
+ q = q * (d_k ** -0.5)
39
+ v = v * beta[..., None]
40
+ k_beta = k * beta[..., None]
41
+
42
+ assert l % chunk_size == 0
43
+
44
+ # compute (I - tri(diag(beta) KK^T))^{-1}
45
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
46
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
47
+ attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
48
+ for i in range(1, chunk_size):
49
+ attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
50
+ attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
51
+
52
+ u = attn @ v
53
+ w = attn @ k_beta
54
+ S = k.new_zeros(b, h, d_k, d_v)
55
+ o = torch.zeros_like(v)
56
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
57
+ for i in range(0, l // chunk_size):
58
+ q_i, k_i = q[:, :, i], k[:, :, i]
59
+ attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
60
+ u_i = u[:, :, i] - w[:, :, i] @ S
61
+ o_inter = q_i @ S
62
+ o[:, :, i] = o_inter + attn @ u_i
63
+ S = S + k_i.transpose(-1, -2) @ u_i
64
+
65
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
66
+
67
+
68
+ def delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
69
+ b, h, l, d_k = q.shape
70
+ # d_v = v.shape[-1]
71
+ q = q * (d_k ** -0.5)
72
+ v = v * beta[..., None]
73
+ k_beta = k * beta[..., None]
74
+ # compute (I - tri(diag(beta) KK^T))^{-1}
75
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
76
+ mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
77
+ T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
78
+ for i in range(1, BN):
79
+ T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
80
+ T = T + torch.eye(BN, dtype=torch.float, device=q.device)
81
+
82
+ mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
83
+ A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
84
+ o_intra = A_local @ v
85
+
86
+ # apply cumprod transition matrices on k to the last position within the chunk
87
+ k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
88
+ # apply cumprod transition matrices on q to the first position within the chunk
89
+ q = q - A_local @ k_beta
90
+ o_intra = A_local @ v
91
+
92
+ A = torch.zeros(b, h, l, l, device=q.device)
93
+
94
+ q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
95
+ o = torch.empty_like(v)
96
+ for i in range(0, l, BM):
97
+ q_i = q[:, :, i:i+BM]
98
+ o_i = o_intra[:, :, i:i+BM]
99
+ # intra block
100
+ for j in range(i + BM - 2 * BN, i-BN, -BN):
101
+ k_j = k[:, :, j:j+BN]
102
+ A_ij = q_i @ k_j.transpose(-1, -2)
103
+ mask = torch.arange(i, i+BM) >= (j + BN)
104
+ A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
105
+ A[:, :, i:i+BM, j:j+BN] = A_ij
106
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
107
+ o_i += A_ij @ v[:, :, j:j+BN]
108
+ # inter block
109
+ for j in range(i - BN, -BN, -BN):
110
+ k_j = k[:, :, j:j+BN]
111
+ A_ij = q_i @ k_j.transpose(-1, -2)
112
+ A[:, :, i:i+BM, j:j+BN] = A_ij
113
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
114
+ o_i += A_ij @ v[:, :, j:j+BN]
115
+ o[:, :, i:i+BM] = o_i
116
+
117
+ for i in range(0, l//BN):
118
+ A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i]
119
+
120
+ return o, A
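
A hypothetical consistency check for the two reference implementations above (not part of the diff). The keys are L2-normalized so the step-by-step delta-rule state update stays well-conditioned, which is also how these kernels are typically used.

import torch
from fla.ops.delta_rule.naive import delta_rule_recurrence, delta_rule_chunkwise

B, H, T, D = 1, 2, 128, 32
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
k = torch.nn.functional.normalize(k, p=2, dim=-1)  # keep ||k|| = 1 for stability
beta = torch.rand(B, H, T)                         # per-token writing strength in (0, 1)
o_rec, S_rec = delta_rule_recurrence(q, k, v, beta)
o_chk, S_chk = delta_rule_chunkwise(q, k, v, beta, chunk_size=32)
print((o_rec - o_chk).abs().max(), (S_rec - S_chk).abs().max())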
fla/ops/forgetting_attn/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (39 kB).
 
fla/ops/forgetting_attn/parallel.py ADDED
@@ -0,0 +1,708 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum
13
+ from fla.ops.utils.op import div, exp, log
14
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
24
+ for num_stages in [2, 3, 4, 5]
25
+ ],
26
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
27
+ )
28
+ @triton.jit
29
+ def parallel_forgetting_attn_fwd_kernel(
30
+ q,
31
+ k,
32
+ v,
33
+ g,
34
+ o,
35
+ lse,
36
+ scale,
37
+ offsets,
38
+ indices,
39
+ T,
40
+ B: tl.constexpr,
41
+ H: tl.constexpr,
42
+ HQ: tl.constexpr,
43
+ G: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BS: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr
51
+ ):
52
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
54
+ i_h = i_hq // G
55
+
56
+ if USE_OFFSETS:
57
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
58
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
59
+ T = eos - bos
60
+ else:
61
+ i_n = i_b
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
65
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
66
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
67
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
68
+
69
+ # the Q block is kept in the shared memory throughout the whole kernel
70
+ # [BT, BK]
71
+ b_q = tl.load(p_q, boundary_check=(0, 1))
72
+ b_q = (b_q * scale).to(b_q.dtype)
73
+ # [BT,]
74
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
75
+ # [BT, BV]
76
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
77
+
78
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
79
+ b_acc = tl.zeros([BT], dtype=tl.float32)
80
+
81
+ # [BT]
82
+ o_q = i_t * BT + tl.arange(0, BT)
83
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
84
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
86
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
87
+
88
+ # [BS]
89
+ o_k = i_s + tl.arange(0, BS)
90
+ # [BK, BS]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BS, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ # [BS,]
95
+ b_gk = tl.load(p_gk, boundary_check=(0,))
96
+ # [BT, BS]
97
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :]
98
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
99
+
100
+ # [BT]
101
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
102
+ b_r = exp(b_mp - b_m)
103
+ # [BT, BS]
104
+ b_p = exp(b_s - b_m[:, None])
105
+ # [BT]
106
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
107
+ # [BT, BV]
108
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
109
+
110
+ b_mp = b_m
111
+
112
+ for i_s in range(i_t * BT - BS, -BS, -BS):
113
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
114
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
115
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
116
+
117
+ # [BK, BS]
118
+ b_k = tl.load(p_k, boundary_check=(0, 1))
119
+ # [BS, BV]
120
+ b_v = tl.load(p_v, boundary_check=(0, 1))
121
+ # [BS,]
122
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
123
+
124
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
125
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
126
+ # [BT, BS]
127
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :]
128
+
129
+ b_gq += b_gn - b_gp
130
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
131
+ b_r = exp(b_mp - b_m)
132
+ # [BT, BS]
133
+ b_p = exp(b_s - b_m[:, None])
134
+ # [BT]
135
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
136
+ # [BT, BV]
137
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
138
+
139
+ b_mp = b_m
140
+
141
+ b_o = div(b_o, b_acc[:, None])
142
+ b_m += log(b_acc)
143
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
144
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
145
+
146
+
147
+ @triton.jit
148
+ def parallel_forgetting_attn_bwd_kernel_preprocess(
149
+ o,
150
+ do,
151
+ delta,
152
+ B: tl.constexpr,
153
+ V: tl.constexpr
154
+ ):
155
+ i_n = tl.program_id(0)
156
+ o_d = tl.arange(0, B)
157
+ m_d = o_d < V
158
+
159
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
160
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
161
+ b_delta = tl.sum(b_o * b_do)
162
+
163
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
164
+
165
+
166
+ @triton.heuristics({
167
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
168
+ })
169
+ @triton.autotune(
170
+ configs=[
171
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
172
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
173
+ for num_stages in [2, 3, 4]
174
+ ],
175
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
176
+ )
177
+ @triton.jit(do_not_specialize=['T'])
178
+ def parallel_forgetting_attn_bwd_kernel_dq(
179
+ q,
180
+ k,
181
+ v,
182
+ g,
183
+ lse,
184
+ delta,
185
+ do,
186
+ dq,
187
+ dg,
188
+ scale,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ B: tl.constexpr,
193
+ H: tl.constexpr,
194
+ HQ: tl.constexpr,
195
+ G: tl.constexpr,
196
+ K: tl.constexpr,
197
+ V: tl.constexpr,
198
+ BT: tl.constexpr,
199
+ BS: tl.constexpr,
200
+ BK: tl.constexpr,
201
+ BV: tl.constexpr,
202
+ USE_OFFSETS: tl.constexpr
203
+ ):
204
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
205
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
206
+ i_h = i_hq // G
207
+
208
+ if USE_OFFSETS:
209
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
210
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
211
+ T = eos - bos
212
+ else:
213
+ i_n = i_b
214
+ bos, eos = i_n * T, i_n * T + T
215
+
216
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
217
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
218
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
219
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
220
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
221
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
222
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
223
+
224
+ # [BT, BK]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale).to(b_q.dtype)
227
+ # [BT, BV]
228
+ b_do = tl.load(p_do, boundary_check=(0, 1))
229
+ # [BT]
230
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
231
+ b_lse = tl.load(p_lse, boundary_check=(0,))
232
+ b_delta = tl.load(p_delta, boundary_check=(0,))
233
+
234
+ # [BT]
235
+ o_q = i_t * BT + tl.arange(0, BT)
236
+ # [BT, BK]
237
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
238
+ # [BT]
239
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
240
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
241
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
242
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
243
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
244
+
245
+ # [BS]
246
+ o_k = i_s + tl.arange(0, BS)
247
+ # [BK, BS]
248
+ b_k = tl.load(p_k, boundary_check=(0, 1))
249
+ # [BV, BS]
250
+ b_v = tl.load(p_v, boundary_check=(0, 1))
251
+ # [BS,]
252
+ b_gk = tl.load(p_gk, boundary_check=(0,))
253
+ # [BT, BS]
254
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :]
255
+ b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')))
256
+
257
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
258
+ b_dp = tl.dot(b_do, b_v)
259
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
260
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
261
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
262
+ # [BT]
263
+ b_dg += tl.sum(b_ds, 1)
264
+
265
+ for i_s in range(i_t * BT - BS, -BS, -BS):
266
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
267
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
268
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
269
+
270
+ # [BK, BS]
271
+ b_k = tl.load(p_k, boundary_check=(0, 1))
272
+ # [BV, BS]
273
+ b_v = tl.load(p_v, boundary_check=(0, 1))
274
+ # [BS,]
275
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
276
+
277
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
278
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
279
+ # [BT, BS]
280
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :]
281
+ b_p = exp(b_s)
282
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
283
+ b_dp = tl.dot(b_do, b_v)
284
+ b_ds = b_p * (b_dp - b_delta[:, None])
285
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
286
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
287
+ # [BT]
288
+ b_dg += tl.sum(b_ds, 1)
289
+
290
+ b_gq += b_gn - b_gp
291
+
292
+ b_dq *= scale
293
+
294
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
295
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
296
+
297
+
298
+ @triton.heuristics({
299
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
300
+ })
301
+ @triton.autotune(
302
+ configs=[
303
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
304
+ for num_warps in [1, 2, 4, 8]
305
+ for num_stages in [2, 3, 4]
306
+ ],
307
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
308
+ )
309
+ @triton.jit(do_not_specialize=['T'])
310
+ def parallel_forgetting_attn_bwd_kernel_dkv(
311
+ q,
312
+ k,
313
+ v,
314
+ g,
315
+ lse,
316
+ delta,
317
+ do,
318
+ dk,
319
+ dv,
320
+ dg,
321
+ offsets,
322
+ indices,
323
+ scale,
324
+ T,
325
+ B: tl.constexpr,
326
+ H: tl.constexpr,
327
+ HQ: tl.constexpr,
328
+ G: tl.constexpr,
329
+ K: tl.constexpr,
330
+ V: tl.constexpr,
331
+ BT: tl.constexpr,
332
+ BS: tl.constexpr,
333
+ BK: tl.constexpr,
334
+ BV: tl.constexpr,
335
+ USE_OFFSETS: tl.constexpr
336
+ ):
337
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
338
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
339
+ i_h = i_hq // G
340
+
341
+ if USE_OFFSETS:
342
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
343
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
344
+ T = eos - bos
345
+ else:
346
+ i_n = i_b
347
+ bos, eos = i_n * T, i_n * T + T
348
+
349
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
350
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
351
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
352
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
353
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
354
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
355
+
356
+ # [BT, BK]
357
+ b_k = tl.load(p_k, boundary_check=(0, 1))
358
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
359
+ # [BT, BV]
360
+ b_v = tl.load(p_v, boundary_check=(0, 1))
361
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
362
+ # [BT]
363
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
364
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
365
+
366
+ o_k = i_t * BT + tl.arange(0, BT)
367
+ m_k = o_k < T
368
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
369
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
370
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
371
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
372
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
373
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
374
+
375
+ # [BS]
376
+ o_q = i_s + tl.arange(0, BS)
377
+ # [BS, BK]
378
+ b_q = tl.load(p_q, boundary_check=(0, 1))
379
+ b_q = (b_q * scale).to(b_q.dtype)
380
+ # [BS, BV]
381
+ b_do = tl.load(p_do, boundary_check=(0, 1))
382
+ # [BS]
383
+ b_lse = tl.load(p_lse, boundary_check=(0,))
384
+ b_delta = tl.load(p_delta, boundary_check=(0,))
385
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
386
+
387
+ m_q = o_q < T
388
+ m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :]
389
+ # [BT, BS]
390
+ b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :]
391
+ b_p = tl.where(m_s, exp(b_s), 0)
392
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
393
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
394
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
395
+ b_dp = tl.dot(b_v, tl.trans(b_do))
396
+ # [BT, BS]
397
+ b_ds = b_p * (b_dp - b_delta[None, :])
398
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
399
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
400
+ # [BT]
401
+ b_dg -= tl.sum(b_ds, 1)
402
+
403
+ b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32)
404
+ for i_s in range((i_t + 1) * BT, T, BS):
405
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
406
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
407
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
408
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
409
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
410
+
411
+ # [BS]
412
+ o_q = i_s + tl.arange(0, BS)
413
+ # [BS, BK]
414
+ b_q = tl.load(p_q, boundary_check=(0, 1))
415
+ b_q = (b_q * scale).to(b_q.dtype)
416
+ # [BS, BV]
417
+ b_do = tl.load(p_do, boundary_check=(0, 1))
418
+ # [BS]
419
+ b_lse = tl.load(p_lse, boundary_check=(0,))
420
+ b_delta = tl.load(p_delta, boundary_check=(0,))
421
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
422
+
423
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
424
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
425
+ # [BT, BS]
426
+ b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :]
427
+ b_p = exp(b_s)
428
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
429
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
430
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
431
+ b_dp = tl.dot(b_v, tl.trans(b_do))
432
+ # [BT, BS]
433
+ b_ds = b_p * (b_dp - b_delta[None, :])
434
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
435
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
436
+ # [BT]
437
+ b_dg -= tl.sum(b_ds, 1)
438
+
439
+ b_gk -= b_gn - b_gp
440
+
441
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
442
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
443
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
444
+
445
+
446
+ def parallel_forgetting_attn_fwd(
447
+ q: torch.Tensor,
448
+ k: torch.Tensor,
449
+ v: torch.Tensor,
450
+ g: torch.Tensor,
451
+ scale: float,
452
+ chunk_size: int = 128,
453
+ offsets: Optional[torch.LongTensor] = None,
454
+ indices: Optional[torch.LongTensor] = None,
455
+ ):
456
+ B, T, H, K, V = *k.shape, v.shape[-1]
457
+ HQ = q.shape[2]
458
+ G = HQ // H
459
+ BT = chunk_size
460
+ BK = max(16, triton.next_power_of_2(K))
461
+ assert V <= 256, "V must be less than or equal to 256"
462
+ if check_shared_mem('hopper'):
463
+ BS = min(64, max(16, triton.next_power_of_2(T)))
464
+ else:
465
+ BS = min(32, max(16, triton.next_power_of_2(T)))
466
+ BV = min(256, max(16, triton.next_power_of_2(V)))
467
+ NV = triton.cdiv(V, BV)
468
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
469
+
470
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
471
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
472
+
473
+ grid = (NV, NT, B * HQ)
474
+ parallel_forgetting_attn_fwd_kernel[grid](
475
+ q=q,
476
+ k=k,
477
+ v=v,
478
+ g=g,
479
+ o=o,
480
+ lse=lse,
481
+ scale=scale,
482
+ offsets=offsets,
483
+ indices=indices,
484
+ B=B,
485
+ T=T,
486
+ H=H,
487
+ HQ=HQ,
488
+ G=G,
489
+ K=K,
490
+ V=V,
491
+ BT=BT,
492
+ BS=BS,
493
+ BK=BK,
494
+ BV=BV,
495
+ )
496
+ return o, lse
497
+
498
+
499
+ def parallel_forgetting_attn_bwd_preprocess(
500
+ o: torch.Tensor,
501
+ do: torch.Tensor
502
+ ):
503
+ V = o.shape[-1]
504
+ delta = torch.empty_like(o[..., 0], dtype=torch.float)
505
+ parallel_forgetting_attn_bwd_kernel_preprocess[(delta.numel(),)](
506
+ o=o,
507
+ do=do,
508
+ delta=delta,
509
+ B=triton.next_power_of_2(V),
510
+ V=V,
511
+ )
512
+ return delta
513
+
514
+
515
+ def parallel_forgetting_attn_bwd(
516
+ q: torch.Tensor,
517
+ k: torch.Tensor,
518
+ v: torch.Tensor,
519
+ g: torch.Tensor,
520
+ o: torch.Tensor,
521
+ lse: torch.Tensor,
522
+ do: torch.Tensor,
523
+ scale: float = None,
524
+ chunk_size: int = 128,
525
+ offsets: Optional[torch.LongTensor] = None,
526
+ indices: Optional[torch.LongTensor] = None,
527
+ ):
528
+ B, T, H, K, V = *k.shape, v.shape[-1]
529
+ HQ = q.shape[2]
530
+ G = HQ // H
531
+ BT = chunk_size
532
+ BS = min(32, max(16, triton.next_power_of_2(T)))
533
+ BK = max(16, triton.next_power_of_2(K))
534
+ BV = max(16, triton.next_power_of_2(V))
535
+ NV = triton.cdiv(V, BV)
536
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
537
+
538
+ delta = parallel_forgetting_attn_bwd_preprocess(o, do)
539
+ dq = q.new_empty(B, T, HQ, K, dtype=q.dtype)
540
+ dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float)
541
+ dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float)
542
+ dg = q.new_empty(g.shape, dtype=torch.float)
543
+ # NOTE: the original `dg` can be destroyed during autotuning
544
+ # this is [a known triton issue](https://github.com/triton-lang/triton/issues/5082), which will be fixed in 3.3 (?)
545
+ # so we need to make a copy of `dg`
546
+ dg2 = q.new_empty(g.shape, dtype=torch.float)
547
+ grid = (NV, NT, B * HQ)
548
+ parallel_forgetting_attn_bwd_kernel_dq[grid](
549
+ q=q,
550
+ k=k,
551
+ v=v,
552
+ g=g,
553
+ lse=lse,
554
+ delta=delta,
555
+ do=do,
556
+ dq=dq,
557
+ dg=dg,
558
+ offsets=offsets,
559
+ indices=indices,
560
+ scale=scale,
561
+ T=T,
562
+ B=B,
563
+ H=H,
564
+ HQ=HQ,
565
+ G=G,
566
+ K=K,
567
+ V=V,
568
+ BT=BT,
569
+ BS=BS,
570
+ BK=BK,
571
+ BV=BV
572
+ )
573
+ parallel_forgetting_attn_bwd_kernel_dkv[grid](
574
+ q=q,
575
+ k=k,
576
+ v=v,
577
+ g=g,
578
+ lse=lse,
579
+ delta=delta,
580
+ do=do,
581
+ dk=dk,
582
+ dv=dv,
583
+ dg=dg2,
584
+ offsets=offsets,
585
+ indices=indices,
586
+ scale=scale,
587
+ T=T,
588
+ B=B,
589
+ H=H,
590
+ HQ=HQ,
591
+ G=G,
592
+ K=K,
593
+ V=V,
594
+ BT=BT,
595
+ BS=BS,
596
+ BK=BK,
597
+ BV=BV
598
+ )
599
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
600
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
601
+ dg = dg.add_(dg2)
602
+ return dq, dk, dv, dg
603
+
604
+
605
+ @torch.compile
606
+ class ParallelForgettingAttentionFunction(torch.autograd.Function):
607
+
608
+ @staticmethod
609
+ @input_guard
610
+ @autocast_custom_fwd
611
+ def forward(ctx, q, k, v, g, scale, offsets):
612
+ ctx.dtype = q.dtype
613
+ if check_shared_mem('hopper'):
614
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
615
+ else:
616
+ chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1])))
617
+ # 2-d indices denoting the offsets of chunks in each sequence
618
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
619
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
620
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
621
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
622
+
623
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False)
624
+ o, lse = parallel_forgetting_attn_fwd(
625
+ q=q,
626
+ k=k,
627
+ v=v,
628
+ g=g,
629
+ scale=scale,
630
+ chunk_size=chunk_size,
631
+ offsets=offsets,
632
+ indices=indices
633
+ )
634
+ ctx.save_for_backward(q, k, v, g, o, lse)
635
+ ctx.chunk_size = chunk_size
636
+ ctx.offsets = offsets
637
+ ctx.indices = indices
638
+ ctx.scale = scale
639
+ return o.to(q.dtype)
640
+
641
+ @staticmethod
642
+ @input_guard
643
+ @autocast_custom_bwd
644
+ def backward(ctx, do):
645
+ q, k, v, g, o, lse = ctx.saved_tensors
646
+ dq, dk, dv, dg = parallel_forgetting_attn_bwd(
647
+ q=q,
648
+ k=k,
649
+ v=v,
650
+ g=g,
651
+ o=o,
652
+ lse=lse,
653
+ do=do,
654
+ scale=ctx.scale,
655
+ chunk_size=ctx.chunk_size,
656
+ offsets=ctx.offsets,
657
+ indices=ctx.indices
658
+ )
659
+ dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets)
660
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None, None, None, None, None, None, None
661
+
662
+
663
+ def parallel_forgetting_attn(
664
+ q: torch.Tensor,
665
+ k: torch.Tensor,
666
+ v: torch.Tensor,
667
+ g: torch.Tensor,
668
+ scale: Optional[float] = None,
669
+ cu_seqlens: Optional[torch.LongTensor] = None,
670
+ head_first: bool = False
671
+ ) -> torch.Tensor:
672
+ r"""
673
+ Args:
674
+ q (torch.Tensor):
675
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
676
+ k (torch.Tensor):
677
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
678
+ GQA will be applied if HQ is divisible by H.
679
+ v (torch.Tensor):
680
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
681
+ g (torch.Tensor):
682
+ Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
683
+ scale (Optional[float]):
684
+ Scale factor for attention scores.
685
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
686
+ cu_seqlens (torch.LongTensor):
687
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
688
+ consistent with the FlashAttention API.
689
+ head_first (Optional[bool]):
690
+ Whether the inputs are in the head-first format. Default: `False`.
691
+
692
+ Returns:
693
+ o (torch.Tensor):
694
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
695
+ """
696
+ if scale is None:
697
+ scale = k.shape[-1] ** -0.5
698
+ if cu_seqlens is not None:
699
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
700
+ if g is not None:
701
+ g = g.float()
702
+ if head_first:
703
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
704
+ g = rearrange(g, 'b h t -> b t h')
705
+ o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens)
706
+ if head_first:
707
+ o = rearrange(o, 'b t h d -> b h t d')
708
+ return o
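A minimal usage sketch for `parallel_forgetting_attn`, assuming a CUDA device (the Triton kernels require one), the default `head_first=False` layout, and GQA with 8 query heads sharing 2 key/value heads; the shapes and dtypes below are illustrative only:

import torch
import torch.nn.functional as F

B, T, HQ, H, K, V = 1, 1024, 8, 2, 64, 64
q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
# forget gates are expected in log space; log-sigmoid keeps them <= 0
g = F.logsigmoid(torch.randn(B, T, HQ, device='cuda', dtype=torch.float))
o = parallel_forgetting_attn(q, k, v, g)  # scale defaults to K ** -0.5
print(o.shape)  # torch.Size([1, 1024, 8, 64])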
fla/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (322 Bytes). View file
 
fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (15.1 kB). View file
 
fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (45.1 kB). View file
 
fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,197 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_h(
31
+ kg,
32
+ v,
33
+ w,
34
+ bg,
35
+ u,
36
+ v_new,
37
+ gk,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+
82
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
83
+ # since we need to keep all of DK in SRAM, we face a severe SRAM memory burden; subchunking alleviates this burden
84
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
85
+ if HEAD_FIRST:
86
+ p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
88
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
92
+ else:
93
+ p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
95
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
96
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
98
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
99
+ # [BK, BC]
100
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+ b_w = tl.load(p_w, boundary_check=(0, 1))
103
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
104
+ b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
105
+ b_hc += tl.dot(b_kg, b_v)
106
+ b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
107
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
108
+
109
+ last_idx = min((i_t + 1) * BT, T) - 1
110
+ if HEAD_FIRST:
111
+ b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
112
+ else:
113
+ b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K +
114
+ tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
115
+ b_h *= exp(b_g_last[:, None])
116
+ b_h += b_hc
117
+
118
+ if STORE_FINAL_STATE:
119
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_fwd_h(
124
+ kg: torch.Tensor,
125
+ v: torch.Tensor,
126
+ w: torch.Tensor,
127
+ u: torch.Tensor,
128
+ bg: torch.Tensor,
129
+ gk: torch.Tensor,
130
+ initial_state: Optional[torch.Tensor] = None,
131
+ output_final_state: bool = False,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *kg.shape, u.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *kg.shape, u.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ # N: the actual number of sequences in the batch with either equal or variable lengths
143
+ if offsets is None:
144
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
145
+ else:
146
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
147
+ BK = triton.next_power_of_2(K)
148
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
149
+ # H100 can have larger block size
150
+
151
+ if check_shared_mem('hopper', kg.device.index):
152
+ BV = 64
153
+ BC = 64 if K <= 128 else 32
154
+ elif check_shared_mem('ampere', kg.device.index): # A100
155
+ BV = 32
156
+ BC = 32
157
+ else:
158
+ BV = 16
159
+ BC = 16
160
+
161
+ BC = min(BT, BC)
162
+ NK = triton.cdiv(K, BK)
163
+ NV = triton.cdiv(V, BV)
164
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
165
+
166
+ if head_first:
167
+ h = kg.new_empty(B, H, NT, K, V)
168
+ else:
169
+ h = kg.new_empty(B, NT, H, K, V)
170
+ final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
171
+ v_new = torch.empty_like(u)
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_fwd_kernel_h[grid](
174
+ kg=kg,
175
+ v=v,
176
+ w=w,
177
+ bg=bg,
178
+ u=u,
179
+ v_new=v_new,
180
+ h=h,
181
+ gk=gk,
182
+ h0=initial_state,
183
+ ht=final_state,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ NT=NT,
195
+ HEAD_FIRST=head_first
196
+ )
197
+ return h, v_new, final_state
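As a reading aid, here is a hedged pure-PyTorch sketch of the per-chunk recurrence that `chunk_dplr_fwd_kernel_h` appears to compute for the `head_first=True` layout: within each chunk the values are first rewritten as v_new = w @ h + u, and the running state is then updated as h_next = diag(exp(gk_last)) @ h + kg^T @ v + bg^T @ v_new. The reference function name and the float32 accumulation below are illustrative assumptions, not part of the library API:

import torch

def chunk_dplr_fwd_h_reference(kg, v, w, u, bg, gk, h0=None, chunk_size=64):
    # kg/w/gk are [B, H, T, K], v/u are [B, H, T, V] (head_first layout);
    # returns per-chunk states h [B, H, NT, K, V], rewritten values v_new and the final state.
    B, H, T, K = kg.shape
    V = u.shape[-1]
    BT = chunk_size
    NT = (T + BT - 1) // BT
    h = kg.new_zeros(B, H, NT, K, V, dtype=torch.float)
    v_new = torch.empty_like(u, dtype=torch.float)
    b_h = kg.new_zeros(B, H, K, V, dtype=torch.float) if h0 is None else h0.float().clone()
    for i_t in range(NT):
        s = slice(i_t * BT, min((i_t + 1) * BT, T))
        h[:, :, i_t] = b_h                              # state at the start of the chunk
        b_v2 = w[:, :, s].float() @ b_h + u[:, :, s].float()
        v_new[:, :, s] = b_v2
        g_last = gk[:, :, s][:, :, -1]                  # [B, H, K], decay at the chunk boundary
        b_h = (b_h * g_last.exp().unsqueeze(-1)
               + kg[:, :, s].float().transpose(-1, -2) @ v[:, :, s].float()
               + bg[:, :, s].float().transpose(-1, -2) @ b_v2)
    return h, v_new.to(u.dtype), b_h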
fla/ops/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (336 Bytes). View file
 
fla/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (81.8 kB). View file
 
fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (5.69 kB). View file
 
fla/ops/gla/chunk.py ADDED
@@ -0,0 +1,1486 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils import chunk_local_cumsum
13
+ from fla.ops.utils.op import exp, safe_exp
14
+ from fla.utils import check_shared_mem, input_guard
15
+
16
+ BK_LIST = [32, 64] if check_shared_mem() else [16, 32]
17
+ BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32]
18
+
19
+
20
+ @triton.heuristics({
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
22
+ })
23
+ @triton.autotune(
24
+ configs=[
25
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
26
+ for BK in [32, 64]
27
+ for num_warps in [1, 2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=["BC"]
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_gla_fwd_A_kernel_intra_sub_inter(
34
+ q,
35
+ k,
36
+ g,
37
+ A,
38
+ offsets,
39
+ indices,
40
+ scale,
41
+ T,
42
+ H: tl.constexpr,
43
+ K: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BC: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ NC: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr,
49
+ HEAD_FIRST: tl.constexpr
50
+ ):
51
+ i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
52
+ i_b, i_h = i_bh // H, i_bh % H
53
+ i_i, i_j = i_c // NC, i_c % NC
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ bos, eos = i_b * T, i_b * T + T
60
+
61
+ if i_t * BT + i_i * BC >= T:
62
+ return
63
+ if i_i <= i_j:
64
+ return
65
+
66
+ b_A = tl.zeros([BC, BC], dtype=tl.float32)
67
+ for i_k in range(tl.cdiv(K, BK)):
68
+ o_k = i_k * BK + tl.arange(0, BK)
69
+ m_k = o_k < K
70
+
71
+ if HEAD_FIRST:
72
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
73
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
74
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
75
+ p_gk = tl.make_block_ptr(g + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
76
+ p_gn = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
77
+ else:
78
+ p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
79
+ p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
80
+ p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
81
+ p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
82
+ p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
83
+
84
+ # [BK,]
85
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
86
+ # [BC, BK]
87
+ b_q = tl.load(p_q, boundary_check=(0, 1))
88
+ b_g = tl.load(p_g, boundary_check=(0, 1))
89
+ b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
90
+ # [BK, BC]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
93
+ b_kg = b_k * exp(b_gn[:, None] - b_gk)
94
+ # [BC, BC] using tf32 to improve precision here.
95
+ b_A += tl.dot(b_qg, b_kg)
96
+
97
+ if HEAD_FIRST:
98
+ p_A = tl.make_block_ptr(A + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
99
+ else:
100
+ p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
101
+ tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
102
+
103
+
104
+ @triton.heuristics({
105
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
106
+ })
107
+ @triton.autotune(
108
+ configs=[
109
+ triton.Config({}, num_warps=1),
110
+ triton.Config({}, num_warps=2),
111
+ triton.Config({}, num_warps=4),
112
+ triton.Config({}, num_warps=8),
113
+ ],
114
+ key=["BK", "BT"]
115
+ )
116
+ @triton.jit(do_not_specialize=['T'])
117
+ def chunk_gla_fwd_A_kernel_intra_sub_intra(
118
+ q,
119
+ k,
120
+ g,
121
+ A,
122
+ offsets,
123
+ indices,
124
+ scale,
125
+ T,
126
+ H: tl.constexpr,
127
+ K: tl.constexpr,
128
+ BT: tl.constexpr,
129
+ BC: tl.constexpr,
130
+ BK: tl.constexpr,
131
+ USE_OFFSETS: tl.constexpr,
132
+ HEAD_FIRST: tl.constexpr
133
+ ):
134
+ i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
135
+ i_b, i_h = i_bh // H, i_bh % H
136
+ i_j = i_i
137
+ if USE_OFFSETS:
138
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
139
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
140
+ T = eos - bos
141
+ else:
142
+ bos, eos = i_b * T, i_b * T + T
143
+
144
+ if i_t * BT + i_i * BC >= T:
145
+ return
146
+
147
+ o_i = tl.arange(0, BC)
148
+ o_k = tl.arange(0, BK)
149
+ m_k = o_k < K
150
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
151
+ if HEAD_FIRST:
152
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
153
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
154
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
155
+ p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK)
156
+ p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK)
157
+ else:
158
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
159
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
160
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
161
+ p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
162
+ p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
163
+
164
+ b_q = tl.load(p_q, boundary_check=(0, 1))
165
+ b_g = tl.load(p_g, boundary_check=(0, 1))
166
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
167
+ b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)
168
+ b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
169
+ b_A = tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1)
170
+ b_A = tl.where(o_i >= j, b_A * scale, 0.)
171
+
172
+ tl.store(A + o_A + j, b_A, mask=m_A)
173
+ p_k += K if HEAD_FIRST else H*K
174
+ p_gk += K if HEAD_FIRST else H*K
175
+
176
+
177
+ @triton.heuristics({
178
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
179
+ })
180
+ @triton.autotune(
181
+ configs=[
182
+ triton.Config({}, num_warps=1),
183
+ triton.Config({}, num_warps=2),
184
+ triton.Config({}, num_warps=4),
185
+ triton.Config({}, num_warps=8),
186
+ ],
187
+ key=['BC', 'BK']
188
+ )
189
+ @triton.jit(do_not_specialize=['T'])
190
+ def chunk_gla_fwd_A_kernel_intra_sub_intra_split(
191
+ q,
192
+ k,
193
+ g,
194
+ A,
195
+ offsets,
196
+ indices,
197
+ scale,
198
+ T,
199
+ B: tl.constexpr,
200
+ H: tl.constexpr,
201
+ K: tl.constexpr,
202
+ BT: tl.constexpr,
203
+ BC: tl.constexpr,
204
+ BK: tl.constexpr,
205
+ NC: tl.constexpr,
206
+ USE_OFFSETS: tl.constexpr,
207
+ HEAD_FIRST: tl.constexpr
208
+ ):
209
+ i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
210
+ i_b, i_h = i_bh // H, i_bh % H
211
+ i_t, i_i = i_tc // NC, i_tc % NC
212
+ i_j = i_i
213
+ if USE_OFFSETS:
214
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
215
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
216
+ all = T
217
+ T = eos - bos
218
+ else:
219
+ bos, eos = i_b * T, i_b * T + T
220
+ all = B * T
221
+
222
+ if i_t * BT + i_i * BC >= T:
223
+ return
224
+
225
+ o_i = tl.arange(0, BC)
226
+ o_k = i_k * BK + tl.arange(0, BK)
227
+ m_k = o_k < K
228
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
229
+
230
+ if HEAD_FIRST:
231
+ o_A = (i_k * B*H + i_bh) * T * BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BC
232
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
233
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
234
+ p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK)
235
+ p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK)
236
+ else:
237
+ o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC
238
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
239
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
240
+ p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
241
+ p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k
242
+
243
+ b_q = tl.load(p_q, boundary_check=(0, 1))
244
+ b_g = tl.load(p_g, boundary_check=(0, 1))
245
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
246
+ b_A = tl.zeros([BC], dtype=tl.float32)
247
+ b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)
248
+ b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
249
+ b_A += tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1)
250
+ b_A = tl.where(o_i >= j, b_A * scale, 0.)
251
+ tl.store(A + o_A + j, b_A, mask=m_A)
252
+ p_k += K if HEAD_FIRST else H*K
253
+ p_gk += K if HEAD_FIRST else H*K
254
+
255
+
256
+ @triton.heuristics({
257
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
258
+ })
259
+ @triton.autotune(
260
+ configs=[
261
+ triton.Config({}, num_warps=1),
262
+ triton.Config({}, num_warps=2),
263
+ triton.Config({}, num_warps=4),
264
+ triton.Config({}, num_warps=8),
265
+ ],
266
+ key=['BC']
267
+ )
268
+ @triton.jit(do_not_specialize=['T'])
269
+ def chunk_gla_fwd_A_kernel_intra_sub_intra_merge(
270
+ A,
271
+ A2,
272
+ offsets,
273
+ indices,
274
+ T,
275
+ B: tl.constexpr,
276
+ H: tl.constexpr,
277
+ BT: tl.constexpr,
278
+ BC: tl.constexpr,
279
+ NK: tl.constexpr,
280
+ USE_OFFSETS: tl.constexpr,
281
+ HEAD_FIRST: tl.constexpr
282
+ ):
283
+ i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
284
+ i_b, i_h = i_bh // H, i_bh % H
285
+ if USE_OFFSETS:
286
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
287
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
288
+ all = T
289
+ T = eos - bos
290
+ else:
291
+ bos, eos = i_b * T, i_b * T + T
292
+ all = B * T
293
+
294
+ if i_t * BT + i_c * BC >= T:
295
+ return
296
+
297
+ b_A = tl.zeros([BC, BC], dtype=tl.float32)
298
+ for i_k in range(0, NK):
299
+ if HEAD_FIRST:
300
+ p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh)*T*BC, (T, BC), (BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0))
301
+ else:
302
+ p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0))
303
+ b_A += tl.load(p_A, boundary_check=(0, 1))
304
+ if HEAD_FIRST:
305
+ p_A2 = tl.make_block_ptr(A2 + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0))
306
+ else:
307
+ p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0))
308
+ tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1))
309
+
310
+
311
+ @triton.heuristics({
312
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
313
+ })
314
+ @triton.autotune(
315
+ configs=[
316
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
317
+ for BK in [32, 64]
318
+ for BV in [64, 128]
319
+ for num_warps in [2, 4, 8]
320
+ ],
321
+ key=['BT'],
322
+ )
323
+ @triton.jit(do_not_specialize=['T'])
324
+ def chunk_gla_fwd_kernel_o(
325
+ q,
326
+ v,
327
+ g,
328
+ h,
329
+ o,
330
+ A,
331
+ offsets,
332
+ indices,
333
+ scale,
334
+ T,
335
+ H: tl.constexpr,
336
+ K: tl.constexpr,
337
+ V: tl.constexpr,
338
+ BT: tl.constexpr,
339
+ BK: tl.constexpr,
340
+ BV: tl.constexpr,
341
+ USE_OFFSETS: tl.constexpr,
342
+ HEAD_FIRST: tl.constexpr
343
+ ):
344
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
345
+ i_b, i_h = i_bh // H, i_bh % H
346
+ if USE_OFFSETS:
347
+ i_tg = i_t
348
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
349
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
350
+ T = eos - bos
351
+ NT = tl.cdiv(T, BT)
352
+ else:
353
+ NT = tl.cdiv(T, BT)
354
+ i_tg = i_b * NT + i_t
355
+ bos, eos = i_b * T, i_b * T + T
356
+
357
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
358
+
359
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
360
+ for i_k in range(tl.cdiv(K, BK)):
361
+ if HEAD_FIRST:
362
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
363
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
364
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
365
+ else:
366
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
367
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
368
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
369
+
370
+ # [BT, BK]
371
+ b_q = tl.load(p_q, boundary_check=(0, 1))
372
+ b_q = (b_q * scale).to(b_q.dtype)
373
+ # [BT, BK]
374
+ b_g = tl.load(p_g, boundary_check=(0, 1))
375
+ # [BT, BK]
376
+ b_qg = (b_q * exp(b_g)).to(b_q.dtype)
377
+ # [BK, BV]
378
+ b_h = tl.load(p_h, boundary_check=(0, 1))
379
+ # works, but we don't know why (dkw), owing to divine benevolence
380
+ # [BT, BV]
381
+ if i_k >= 0:
382
+ b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
383
+ if HEAD_FIRST:
384
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
385
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
386
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
387
+ else:
388
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
389
+ p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
390
+ p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
391
+ # [BT, BV]
392
+ b_v = tl.load(p_v, boundary_check=(0, 1))
393
+ # [BT, BT]
394
+ b_A = tl.load(p_A, boundary_check=(0, 1))
395
+ b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype)
396
+ b_o += tl.dot(b_A, b_v, allow_tf32=False)
397
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
398
+
399
+
400
+ @triton.heuristics({
401
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
402
+ })
403
+ @triton.autotune(
404
+ configs=[
405
+ triton.Config({}, num_warps=1),
406
+ triton.Config({}, num_warps=2),
407
+ triton.Config({}, num_warps=4),
408
+ triton.Config({}, num_warps=8),
409
+ ],
410
+ key=['BK', 'NC', 'BT'],
411
+ )
412
+ @triton.jit(do_not_specialize=['T'])
413
+ def chunk_gla_bwd_kernel_intra(
414
+ q,
415
+ k,
416
+ g,
417
+ dA,
418
+ dq,
419
+ dk,
420
+ offsets,
421
+ indices,
422
+ T,
423
+ H: tl.constexpr,
424
+ K: tl.constexpr,
425
+ BT: tl.constexpr,
426
+ BC: tl.constexpr,
427
+ BK: tl.constexpr,
428
+ NC: tl.constexpr,
429
+ USE_OFFSETS: tl.constexpr,
430
+ HEAD_FIRST: tl.constexpr
431
+ ):
432
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
433
+ i_b, i_h = i_bh // H, i_bh % H
434
+ i_t, i_i = i_c // NC, i_c % NC
435
+ if USE_OFFSETS:
436
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
437
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
438
+ else:
439
+ bos, eos = i_b * T, i_b * T + T
440
+ T = eos - bos
441
+ if i_t * BT + i_i * BC >= T:
442
+ return
443
+
444
+ o_k = i_k * BK + tl.arange(0, BK)
445
+ m_k = o_k < K
446
+
447
+ if HEAD_FIRST:
448
+ p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
449
+ else:
450
+ p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
451
+ # [BC, BK]
452
+ b_g = tl.load(p_g, boundary_check=(0, 1))
453
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
454
+ if i_i > 0:
455
+ if HEAD_FIRST:
456
+ p_gn = g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k
457
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
458
+ else:
459
+ p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k
460
+
461
+ # [BK,]
462
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
463
+ for i_j in range(0, i_i):
464
+ if HEAD_FIRST:
465
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
466
+ p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
467
+ p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
468
+ else:
469
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
470
+ p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0))
471
+ p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
472
+ # [BC, BK]
473
+ b_k = tl.load(p_k, boundary_check=(0, 1))
474
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
475
+ b_kg = (b_k * exp(b_gn[None, :] - b_gk))
476
+ # [BC, BC]
477
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
478
+ # [BC, BK]
479
+ b_dq += tl.dot(b_dA, b_kg)
480
+ b_dq *= exp(b_g - b_gn[None, :])
481
+
482
+ o_i = tl.arange(0, BC)
483
+ m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
484
+ if HEAD_FIRST:
485
+ o_dA = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
486
+ p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
487
+ p_gkj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
488
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
489
+ else:
490
+ o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC
491
+ p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
492
+ p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
493
+ p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
494
+
495
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
496
+ # [BC,]
497
+ b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
498
+ # [BK,]
499
+ b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32)
500
+ b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32)
501
+ # [BC, BK]
502
+ m_i = o_i[:, None] >= j
503
+ # [BC, BK]
504
+ # (SY 09/17) important to not use bf16 here to have a good precision.
505
+ b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_g - b_gkj[None, :]), 0.)
506
+ p_kj += K if HEAD_FIRST else H*K
507
+ p_gkj += K if HEAD_FIRST else H*K
508
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
509
+
510
+ tl.debug_barrier()
511
+ if HEAD_FIRST:
512
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
513
+ p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
514
+ else:
515
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
516
+ p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
517
+
518
+ # [BC, BK]
519
+ b_k = tl.load(p_k, boundary_check=(0, 1))
520
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
521
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
522
+
523
+ NC = min(NC, tl.cdiv(T - i_t * BT, BC))
524
+ if i_i < NC - 1:
525
+ if HEAD_FIRST:
526
+ p_gn = g + (i_bh * T + min(i_t * BT + i_i * BC + BC, T) - 1) * K + o_k
527
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
528
+ else:
529
+ p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k
530
+
531
+ # [BK,]
532
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
533
+ for i_j in range(i_i + 1, NC):
534
+ if HEAD_FIRST:
535
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t*BT + i_j*BC, i_k*BK), (BC, BK), (1, 0))
536
+ p_gq = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t*BT + i_j*BC, i_k*BK), (BC, BK), (1, 0))
537
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (BT, T), (1, BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
538
+ else:
539
+ p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0))
540
+ p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0))
541
+ p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
542
+ # [BC, BK]
543
+ b_q = tl.load(p_q, boundary_check=(0, 1))
544
+ b_gq = tl.load(p_gq, boundary_check=(0, 1))
545
+ b_qg = b_q * safe_exp(b_gq - b_gn[None, :])
546
+ # [BC, BC]
547
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
548
+ # [BC, BK]
549
+ # (SY 09/17) important not to use bf16 here, to preserve precision.
550
+ b_dk += tl.dot(b_dA, b_qg)
551
+ b_dk *= exp(b_gn[None, :] - b_gk)
552
+ if HEAD_FIRST:
553
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
554
+ p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
555
+ p_gqj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK)
556
+ p_dk = tl.make_block_ptr(dk + i_bh*T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
557
+ else:
558
+ o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC)
559
+ p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
560
+ p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k
561
+ p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
562
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
563
+ # [BC,]
564
+ b_dA = tl.load(dA + o_dA + j * (1 if HEAD_FIRST else H) * BT)
565
+ # [BK,]
566
+ b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32)
567
+ b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32)
568
+ # [BC, BK]
569
+ m_i = o_i[:, None] <= j
570
+ b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.)
571
+ p_qj += K if HEAD_FIRST else H*K
572
+ p_gqj += K if HEAD_FIRST else H*K
573
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
574
+
575
+
576
+ @triton.heuristics({
577
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
578
+ })
579
+ @triton.autotune(
580
+ configs=[
581
+ triton.Config({}, num_warps=1),
582
+ triton.Config({}, num_warps=2),
583
+ triton.Config({}, num_warps=4),
584
+ triton.Config({}, num_warps=8),
585
+ ],
586
+ key=['BV', 'BT'],
587
+ )
588
+ @triton.jit(do_not_specialize=['T'])
589
+ def chunk_gla_bwd_kernel_dA(
590
+ v,
591
+ do,
592
+ dA,
593
+ offsets,
594
+ indices,
595
+ scale,
596
+ T,
597
+ H: tl.constexpr,
598
+ V: tl.constexpr,
599
+ BT: tl.constexpr,
600
+ BV: tl.constexpr,
601
+ USE_OFFSETS: tl.constexpr,
602
+ HEAD_FIRST: tl.constexpr
603
+ ):
604
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
605
+ i_b, i_h = i_bh // H, i_bh % H
606
+ if USE_OFFSETS:
607
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
608
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
609
+ else:
610
+ bos, eos = i_b * T, i_b * T + T
611
+ T = eos - bos
612
+
613
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
614
+ for i_v in range(tl.cdiv(V, BV)):
615
+ if HEAD_FIRST:
616
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
617
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
618
+ else:
619
+ p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
620
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
621
+ b_v = tl.load(p_v, boundary_check=(0, 1))
622
+ b_do = tl.load(p_do, boundary_check=(0, 1))
623
+ b_dA += tl.dot(b_do, b_v)
624
+ if HEAD_FIRST:
625
+ p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
626
+ else:
627
+ p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
628
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
629
+ b_dA = tl.where(m_s, b_dA * scale, 0.)
630
+ tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
631
+
632
+
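# Reference sketch for `chunk_gla_bwd_kernel_dA` above (illustrative, not part of this file):
# per chunk and head, the kernel is numerically equivalent to the PyTorch code below, where
# `do_c` and `v_c` are assumed [BT, V] slices of `do` and `v` for one chunk and head.
import torch

def chunk_dA_reference(do_c: torch.Tensor, v_c: torch.Tensor, scale: float) -> torch.Tensor:
    # full [BT, BT] product first, then keep only the causal part (lower triangle incl.
    # diagonal) and apply the attention scale, exactly as the masked store in the kernel does
    dA = do_c.float() @ v_c.float().transpose(-1, -2)
    return torch.tril(dA) * scale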
633
+ @triton.heuristics({
634
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
635
+ })
636
+ @triton.autotune(
637
+ configs=[
638
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
639
+ for BK in BK_LIST
640
+ for BV in BV_LIST
641
+ for num_warps in [2, 4, 8]
642
+ ],
643
+ key=['BT'],
644
+ )
645
+ @triton.jit(do_not_specialize=['T'])
646
+ def chunk_gla_bwd_kernel_dv(
647
+ k,
648
+ g,
649
+ A,
650
+ do,
651
+ dh,
652
+ dv,
653
+ offsets,
654
+ indices,
655
+ T,
656
+ H: tl.constexpr,
657
+ K: tl.constexpr,
658
+ V: tl.constexpr,
659
+ BT: tl.constexpr,
660
+ BK: tl.constexpr,
661
+ BV: tl.constexpr,
662
+ USE_OFFSETS: tl.constexpr,
663
+ HEAD_FIRST: tl.constexpr
664
+ ):
665
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
666
+ i_b, i_h = i_bh // H, i_bh % H
667
+ if USE_OFFSETS:
668
+ i_tg = i_t
669
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
670
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
671
+ T = eos - bos
672
+ NT = tl.cdiv(T, BT)
673
+ else:
674
+ NT = tl.cdiv(T, BT)
675
+ i_tg = i_b * NT + i_t
676
+ bos, eos = i_b * T, i_b * T + T
677
+
678
+ if HEAD_FIRST:
679
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
680
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
681
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
682
+ else:
683
+ p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
684
+ p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
685
+ p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
686
+
687
+ b_A = tl.load(p_A, boundary_check=(0, 1))
688
+ b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.)
689
+ b_do = tl.load(p_do, boundary_check=(0, 1))
690
+ # (SY 09/17) important to disallow tf32 here to maintain good precision.
691
+ b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False)
692
+
693
+ for i_k in range(tl.cdiv(K, BK)):
694
+ o_k = i_k * BK + tl.arange(0, BK)
695
+ m_k = o_k < K
696
+
697
+ if HEAD_FIRST:
698
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
699
+ p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
700
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + min(i_t * BT + BT, T) * K - K + o_k, BK), BK)
701
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
702
+ else:
703
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
704
+ p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
705
+ p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k
706
+ p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
707
+
708
+ b_k = tl.load(p_k, boundary_check=(0, 1))
709
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
710
+ b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk)
711
+ b_k = (b_k * b_gn).to(b_k.dtype)
712
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
713
+ # [BT, BV]
714
+ # (SY 09/17) it is OK to compute the inter-chunk gradient contribution in bf16 here
715
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))
716
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
717
+
718
+
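# Reference sketch for `chunk_gla_bwd_kernel_dv` above (illustrative, not part of this file):
# for one chunk and head, dv combines an intra-chunk term (transposed causal attention matrix
# times do) with an inter-chunk term (keys rescaled by the remaining decay inside the chunk,
# times dh). Assumed chunk slices: A_c [BT, BT], do_c [BT, V], k_c/g_c [BT, K] with g_c the
# chunk-local cumulative gates, dh_c [K, V]; everything is kept in fp32 for clarity.
import torch

def chunk_dv_reference(A_c, do_c, k_c, g_c, dh_c):
    # intra-chunk: dv = triu(A^T) @ do, so key j only receives gradient from queries i >= j
    dv = torch.triu(A_c.float().transpose(-1, -2)) @ do_c.float()
    # inter-chunk: decay each key from its position to the end of the chunk, then apply dh
    g_last = g_c[-1]                                  # gates at the chunk's last position, [K]
    dv = dv + (k_c.float() * torch.exp(g_last[None, :] - g_c)) @ dh_c.float()
    return dv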
719
+ @triton.heuristics({
720
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
721
+ })
722
+ @triton.autotune(
723
+ configs=[
724
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps)
725
+ for BK in BK_LIST
726
+ for BV in BV_LIST
727
+ for num_warps in [2, 4, 8]
728
+ ],
729
+ key=['BT'],
730
+ )
731
+ @triton.jit(do_not_specialize=['T'])
732
+ def chunk_gla_bwd_kernel_inter(
733
+ q,
734
+ k,
735
+ v,
736
+ h,
737
+ g,
738
+ do,
739
+ dh,
740
+ dq,
741
+ dk,
742
+ dq2,
743
+ dk2,
744
+ dg,
745
+ offsets,
746
+ indices,
747
+ scale,
748
+ T,
749
+ H: tl.constexpr,
750
+ K: tl.constexpr,
751
+ V: tl.constexpr,
752
+ BT: tl.constexpr,
753
+ BK: tl.constexpr,
754
+ BV: tl.constexpr,
755
+ USE_OFFSETS: tl.constexpr,
756
+ HEAD_FIRST: tl.constexpr
757
+ ):
758
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
759
+ i_b, i_h = i_bh // H, i_bh % H
760
+ if USE_OFFSETS:
761
+ i_tg = i_t
762
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
763
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
764
+ T = eos - bos
765
+ NT = tl.cdiv(T, BT)
766
+ else:
767
+ NT = tl.cdiv(T, BT)
768
+ i_tg = i_b * NT + i_t
769
+ bos, eos = i_b * T, i_b * T + T
770
+ o_k = i_k * BK + tl.arange(0, BK)
771
+ m_k = o_k < K
772
+
773
+ if HEAD_FIRST:
774
+ p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
775
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (min(T, i_t * BT + BT)-1) * K + o_k, BK), BK)
776
+ else:
777
+ p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
778
+ p_gn = g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k
779
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
780
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
781
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
782
+ b_dgk = tl.zeros([BK,], dtype=tl.float32)
783
+
784
+ for i_v in range(tl.cdiv(V, BV)):
785
+ if HEAD_FIRST:
786
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
787
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
788
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
789
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
790
+ else:
791
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
792
+ p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
793
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
794
+ p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
795
+ # [BT, BV]
796
+ b_v = tl.load(p_v, boundary_check=(0, 1))
797
+ b_do = tl.load(p_do, boundary_check=(0, 1))
798
+ # [BV, BK]
799
+ b_h = tl.load(p_h, boundary_check=(0, 1))
800
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
801
+ # [BK]
802
+ b_dgk += tl.sum(b_h * b_dh, axis=0)
803
+ # [BT, BK]
804
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
805
+ b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
806
+ b_dgk *= exp(b_gn)
807
+ b_dq *= scale
808
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
809
+ b_dq = b_dq * exp(b_gk)
810
+ b_dk = b_dk * exp(b_gn[None, :] - b_gk)
811
+
812
+ if HEAD_FIRST:
813
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
814
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
815
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
816
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
817
+ else:
818
+ p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
819
+ p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
820
+ p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
821
+ p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
822
+ b_q = tl.load(p_q, boundary_check=(0, 1))
823
+ b_k = tl.load(p_k, boundary_check=(0, 1))
824
+ b_dgk += tl.sum(b_dk * b_k, axis=0)
825
+ b_dq += tl.load(p_dq, boundary_check=(0, 1))
826
+ b_dk += tl.load(p_dk, boundary_check=(0, 1))
827
+ b_dg = b_q * b_dq - b_k * b_dk
828
+ # tl.debug_barrier()
829
+ b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :]
830
+ # The equivalent masked-matmul form below is buggy due to a Triton compiler issue,
831
+ # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.)
832
+ # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :]
833
+ if HEAD_FIRST:
834
+ p_dq = tl.make_block_ptr(dq2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
835
+ p_dk = tl.make_block_ptr(dk2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
836
+ p_dg = tl.make_block_ptr(dg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
837
+ else:
838
+ p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
839
+ p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
840
+ p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
841
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
842
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
843
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
844
+
845
+
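# Note on the `b_dg` line in `chunk_gla_bwd_kernel_inter` above: the expression
# `b_dg - tl.cumsum(b_dg) + tl.sum(b_dg)` is a suffix (reverse) cumulative sum written without
# a flip, since x[t] - sum_{s<=t} x[s] + sum_s x[s] = sum_{s>=t} x[s].
# A quick, purely illustrative PyTorch check of that identity:
import torch

x = torch.randn(64, 8)                                            # a dummy [BT, BK] block
suffix = x - torch.cumsum(x, dim=0) + x.sum(dim=0, keepdim=True)
flipped = torch.flip(torch.cumsum(torch.flip(x, dims=[0]), dim=0), dims=[0])
assert torch.allclose(suffix, flipped, atol=1e-5)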
846
+ def chunk_gla_fwd_intra_gk(
847
+ q: torch.Tensor,
848
+ k: torch.Tensor,
849
+ g: torch.Tensor,
850
+ scale: float,
851
+ offsets: Optional[torch.LongTensor] = None,
852
+ indices: Optional[torch.LongTensor] = None,
853
+ head_first: bool = True,
854
+ chunk_size: int = 64
855
+ ):
856
+ if head_first:
857
+ B, H, T, K = k.shape
858
+ else:
859
+ B, T, H, K = k.shape
860
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
861
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
862
+ BC = min(16, BT)
863
+ NC = triton.cdiv(BT, BC)
864
+
865
+ A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
866
+ grid = (NT, NC * NC, B * H)
867
+ chunk_gla_fwd_A_kernel_intra_sub_inter[grid](
868
+ q,
869
+ k,
870
+ g,
871
+ A,
872
+ offsets,
873
+ indices,
874
+ scale,
875
+ T=T,
876
+ H=H,
877
+ K=K,
878
+ BT=BT,
879
+ BC=BC,
880
+ NC=NC,
881
+ HEAD_FIRST=head_first
882
+ )
883
+
884
+ grid = (NT, NC, B * H)
885
+ # load the entire [BC, K] blocks into SRAM at once
886
+ if K <= 256:
887
+ BK = triton.next_power_of_2(K)
888
+ chunk_gla_fwd_A_kernel_intra_sub_intra[grid](
889
+ q,
890
+ k,
891
+ g,
892
+ A,
893
+ offsets,
894
+ indices,
895
+ scale,
896
+ T=T,
897
+ H=H,
898
+ K=K,
899
+ BT=BT,
900
+ BC=BC,
901
+ BK=BK,
902
+ HEAD_FIRST=head_first
903
+ )
904
+ # split then merge
905
+ else:
906
+ BK = min(128, triton.next_power_of_2(K))
907
+ NK = triton.cdiv(K, BK)
908
+ A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float)
909
+
910
+ grid = (NK, NT * NC, B * H)
911
+ chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid](
912
+ q,
913
+ k,
914
+ g,
915
+ A_intra,
916
+ offsets,
917
+ indices,
918
+ scale,
919
+ T=T,
920
+ B=B,
921
+ H=H,
922
+ K=K,
923
+ BT=BT,
924
+ BC=BC,
925
+ BK=BK,
926
+ NC=NC,
927
+ HEAD_FIRST=head_first
928
+ )
929
+
930
+ grid = (NT, NC, B * H)
931
+ chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid](
932
+ A_intra,
933
+ A,
934
+ offsets,
935
+ indices,
936
+ T=T,
937
+ B=B,
938
+ H=H,
939
+ BT=BT,
940
+ BC=BC,
941
+ NK=NK,
942
+ HEAD_FIRST=head_first
943
+ )
944
+ return A
945
+
946
+
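# Worked example of the launch geometry in `chunk_gla_fwd_intra_gk` above, under assumed
# sizes (T=2048, chunk_size=64); the numbers are for illustration only.
import triton

T, chunk_size = 2048, 64
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))     # 64: tokens per chunk
NT = triton.cdiv(T, BT)                                      # 32: number of chunks
BC = min(16, BT)                                             # 16: tokens per sub-chunk
NC = triton.cdiv(BT, BC)                                     # 4:  sub-chunks per chunk
# sub-inter kernel grid: (NT, NC * NC, B * H); sub-intra kernel grid: (NT, NC, B * H)
print(BT, NT, BC, NC)                                        # 64 32 16 4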
947
+ def chunk_gla_fwd_o_gk(
948
+ q: torch.Tensor,
949
+ v: torch.Tensor,
950
+ g: torch.Tensor,
951
+ A: torch.Tensor,
952
+ h: torch.Tensor,
953
+ scale: float,
954
+ offsets: Optional[torch.LongTensor] = None,
955
+ indices: Optional[torch.LongTensor] = None,
956
+ head_first: bool = True,
957
+ chunk_size: int = 64
958
+ ):
959
+ if head_first:
960
+ B, H, T, K, V = *q.shape, v.shape[-1]
961
+ else:
962
+ B, T, H, K, V = *q.shape, v.shape[-1]
963
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
964
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
965
+
966
+ o = torch.empty_like(v)
967
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
968
+ chunk_gla_fwd_kernel_o[grid](
969
+ q,
970
+ v,
971
+ g,
972
+ h,
973
+ o,
974
+ A,
975
+ offsets,
976
+ indices,
977
+ scale,
978
+ T=T,
979
+ H=H,
980
+ K=K,
981
+ V=V,
982
+ BT=BT,
983
+ HEAD_FIRST=head_first
984
+ )
985
+ return o
986
+
987
+
988
+ def chunk_gla_bwd_dA(
989
+ v: torch.Tensor,
990
+ do: torch.Tensor,
991
+ scale: float,
992
+ offsets: Optional[torch.LongTensor] = None,
993
+ indices: Optional[torch.LongTensor] = None,
994
+ head_first: bool = True,
995
+ chunk_size: int = 64
996
+ ):
997
+ if head_first:
998
+ B, H, T, V = v.shape
999
+ else:
1000
+ B, T, H, V = v.shape
1001
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
1002
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
1003
+ BV = min(64, triton.next_power_of_2(V))
1004
+
1005
+ dA = v.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
1006
+ grid = (NT, B * H)
1007
+ chunk_gla_bwd_kernel_dA[grid](
1008
+ v,
1009
+ do,
1010
+ dA,
1011
+ offsets,
1012
+ indices,
1013
+ scale,
1014
+ T=T,
1015
+ H=H,
1016
+ V=V,
1017
+ BT=BT,
1018
+ BV=BV,
1019
+ HEAD_FIRST=head_first
1020
+ )
1021
+ return dA
1022
+
1023
+
1024
+ def chunk_gla_bwd_dv(
1025
+ k: torch.Tensor,
1026
+ g: torch.Tensor,
1027
+ A: torch.Tensor,
1028
+ do: torch.Tensor,
1029
+ dh: torch.Tensor,
1030
+ offsets: Optional[torch.LongTensor] = None,
1031
+ indices: Optional[torch.LongTensor] = None,
1032
+ head_first: bool = True,
1033
+ chunk_size: int = 64
1034
+ ):
1035
+ if head_first:
1036
+ B, H, T, K, V = *k.shape, do.shape[-1]
1037
+ else:
1038
+ B, T, H, K, V = *k.shape, do.shape[-1]
1039
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
1040
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
1041
+
1042
+ dv = torch.empty_like(do)
1043
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
1044
+ chunk_gla_bwd_kernel_dv[grid](
1045
+ k,
1046
+ g,
1047
+ A,
1048
+ do,
1049
+ dh,
1050
+ dv,
1051
+ offsets,
1052
+ indices,
1053
+ T=T,
1054
+ H=H,
1055
+ K=K,
1056
+ V=V,
1057
+ BT=BT,
1058
+ HEAD_FIRST=head_first
1059
+ )
1060
+ return dv
1061
+
1062
+
1063
+ def chunk_gla_bwd_dqk_intra(
1064
+ q: torch.Tensor,
1065
+ k: torch.Tensor,
1066
+ g: torch.Tensor,
1067
+ dA: torch.Tensor,
1068
+ offsets: Optional[torch.LongTensor] = None,
1069
+ indices: Optional[torch.LongTensor] = None,
1070
+ head_first: bool = True,
1071
+ chunk_size: int = 64
1072
+ ):
1073
+ if head_first:
1074
+ B, H, T, K = q.shape
1075
+ else:
1076
+ B, T, H, K = q.shape
1077
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
1078
+ BC = min(16, BT)
1079
+ BK = min(64, triton.next_power_of_2(K))
1080
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
1081
+ NC = triton.cdiv(BT, BC)
1082
+ NK = triton.cdiv(K, BK)
1083
+
1084
+ dq = torch.empty_like(q, dtype=torch.float)
1085
+ dk = torch.empty_like(k, dtype=torch.float)
1086
+ grid = (NK, NT * NC, B * H)
1087
+ chunk_gla_bwd_kernel_intra[grid](
1088
+ q,
1089
+ k,
1090
+ g,
1091
+ dA,
1092
+ dq,
1093
+ dk,
1094
+ offsets,
1095
+ indices,
1096
+ T=T,
1097
+ H=H,
1098
+ K=K,
1099
+ BT=BT,
1100
+ BC=BC,
1101
+ BK=BK,
1102
+ NC=NC,
1103
+ HEAD_FIRST=head_first
1104
+ )
1105
+ return dq, dk
1106
+
1107
+
1108
+ def chunk_gla_bwd_dqkg(
1109
+ q: torch.Tensor,
1110
+ k: torch.Tensor,
1111
+ v: torch.Tensor,
1112
+ h: torch.Tensor,
1113
+ g: torch.Tensor,
1114
+ do: torch.Tensor,
1115
+ dh: torch.Tensor,
1116
+ dq: torch.Tensor,
1117
+ dk: torch.Tensor,
1118
+ scale: float,
1119
+ offsets: Optional[torch.LongTensor] = None,
1120
+ indices: Optional[torch.LongTensor] = None,
1121
+ head_first: bool = True,
1122
+ chunk_size: int = 64
1123
+ ):
1124
+ if head_first:
1125
+ B, H, T, K, V = *k.shape, v.shape[-1]
1126
+ else:
1127
+ B, T, H, K, V = *k.shape, v.shape[-1]
1128
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
1129
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
1130
+
1131
+ dg = torch.empty_like(g)
1132
+ # work around Triton compiler bugs: write the results into fresh dq2/dk2 buffers instead of in place.
1133
+ dq2 = torch.empty_like(dq)
1134
+ dk2 = torch.empty_like(dk)
1135
+ def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H)
1136
+ chunk_gla_bwd_kernel_inter[grid](
1137
+ q,
1138
+ k,
1139
+ v,
1140
+ h,
1141
+ g,
1142
+ do,
1143
+ dh,
1144
+ dq,
1145
+ dk,
1146
+ dq2,
1147
+ dk2,
1148
+ dg,
1149
+ offsets,
1150
+ indices,
1151
+ scale,
1152
+ T=T,
1153
+ H=H,
1154
+ K=K,
1155
+ V=V,
1156
+ BT=BT,
1157
+ HEAD_FIRST=head_first
1158
+ )
1159
+ return dq2, dk2, dg
1160
+
1161
+
1162
+ def chunk_gla_fwd(
1163
+ q: torch.Tensor,
1164
+ k: torch.Tensor,
1165
+ v: torch.Tensor,
1166
+ g: torch.Tensor,
1167
+ g_cumsum: Optional[torch.Tensor],
1168
+ scale: float,
1169
+ initial_state: torch.Tensor,
1170
+ output_final_state: bool,
1171
+ offsets: Optional[torch.LongTensor] = None,
1172
+ indices: Optional[torch.LongTensor] = None,
1173
+ head_first: bool = True,
1174
+ chunk_size: int = 64
1175
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1176
+ T = q.shape[2] if head_first else q.shape[1]
1177
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
1178
+ if g_cumsum is None:
1179
+ g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first)
1180
+
1181
+ h, ht = chunk_fwd_h(
1182
+ k=k,
1183
+ v=v,
1184
+ g=None,
1185
+ gk=g_cumsum,
1186
+ gv=None,
1187
+ h0=initial_state,
1188
+ output_final_state=output_final_state,
1189
+ states_in_fp32=False,
1190
+ offsets=offsets,
1191
+ head_first=head_first,
1192
+ chunk_size=BT
1193
+ )
1194
+
1195
+ # the intra-chunk A is kept in fp32;
1196
+ # this has only a marginal effect on overall throughput
1197
+ A = chunk_gla_fwd_intra_gk(
1198
+ q=q,
1199
+ k=k,
1200
+ g=g_cumsum,
1201
+ scale=scale,
1202
+ offsets=offsets,
1203
+ indices=indices,
1204
+ head_first=head_first,
1205
+ chunk_size=BT
1206
+ )
1207
+ o = chunk_gla_fwd_o_gk(
1208
+ q=q,
1209
+ v=v,
1210
+ g=g_cumsum,
1211
+ A=A,
1212
+ h=h,
1213
+ scale=scale,
1214
+ offsets=offsets,
1215
+ indices=indices,
1216
+ head_first=head_first,
1217
+ chunk_size=BT
1218
+ )
1219
+ return g_cumsum, A, h, ht, o
1220
+
1221
+
1222
+ def chunk_gla_bwd(
1223
+ q: torch.Tensor,
1224
+ k: torch.Tensor,
1225
+ v: torch.Tensor,
1226
+ g: torch.Tensor,
1227
+ g_cumsum: Optional[torch.Tensor],
1228
+ scale: float,
1229
+ initial_state: torch.Tensor,
1230
+ h: torch.Tensor,
1231
+ A: torch.Tensor,
1232
+ do: torch.Tensor,
1233
+ dht: torch.Tensor,
1234
+ offsets: Optional[torch.LongTensor] = None,
1235
+ indices: Optional[torch.LongTensor] = None,
1236
+ head_first: bool = True,
1237
+ chunk_size: int = 64
1238
+ ):
1239
+ T = q.shape[2] if head_first else q.shape[1]
1240
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
1241
+ if g_cumsum is None:
1242
+ g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first)
1243
+
1244
+ if h is None:
1245
+ h, _ = chunk_fwd_h(
1246
+ k=k,
1247
+ v=v,
1248
+ g=None,
1249
+ gk=g_cumsum,
1250
+ gv=None,
1251
+ h0=initial_state,
1252
+ output_final_state=False,
1253
+ offsets=offsets,
1254
+ head_first=head_first,
1255
+ chunk_size=BT,
1256
+ states_in_fp32=True
1257
+ )
1258
+ dh, dh0 = chunk_bwd_dh(
1259
+ q=q,
1260
+ k=k,
1261
+ v=v,
1262
+ g=None,
1263
+ gk=g_cumsum,
1264
+ gv=None,
1265
+ do=do,
1266
+ h0=initial_state,
1267
+ dht=dht,
1268
+ scale=scale,
1269
+ offsets=offsets,
1270
+ head_first=head_first,
1271
+ chunk_size=BT,
1272
+ states_in_fp32=True
1273
+ )
1274
+
1275
+ dv = chunk_gla_bwd_dv(
1276
+ k=k,
1277
+ g=g_cumsum,
1278
+ A=A,
1279
+ do=do,
1280
+ dh=dh,
1281
+ offsets=offsets,
1282
+ indices=indices,
1283
+ head_first=head_first,
1284
+ chunk_size=BT
1285
+ )
1286
+
1287
+ # dq dk in fp32
1288
+ dA = chunk_gla_bwd_dA(
1289
+ v=v,
1290
+ do=do,
1291
+ scale=scale,
1292
+ offsets=offsets,
1293
+ indices=indices,
1294
+ head_first=head_first,
1295
+ chunk_size=BT
1296
+ )
1297
+ dq, dk = chunk_gla_bwd_dqk_intra(
1298
+ q=q,
1299
+ k=k,
1300
+ g=g_cumsum,
1301
+ dA=dA,
1302
+ offsets=offsets,
1303
+ indices=indices,
1304
+ head_first=head_first,
1305
+ chunk_size=BT
1306
+ )
1307
+ dq, dk, dg = chunk_gla_bwd_dqkg(
1308
+ q=q,
1309
+ k=k,
1310
+ v=v,
1311
+ h=h,
1312
+ g=g_cumsum,
1313
+ do=do,
1314
+ dh=dh,
1315
+ dq=dq,
1316
+ dk=dk,
1317
+ scale=scale,
1318
+ offsets=offsets,
1319
+ indices=indices,
1320
+ head_first=head_first,
1321
+ chunk_size=BT
1322
+ )
1323
+ return dq, dk, dv, dg, dh0
1324
+
1325
+
1326
+ class ChunkGLAFunction(torch.autograd.Function):
1327
+
1328
+ @staticmethod
1329
+ @input_guard
1330
+ def forward(
1331
+ ctx,
1332
+ q,
1333
+ k,
1334
+ v,
1335
+ g,
1336
+ scale,
1337
+ initial_state,
1338
+ output_final_state,
1339
+ offsets,
1340
+ head_first
1341
+ ):
1342
+ T = q.shape[2] if head_first else q.shape[1]
1343
+ chunk_size = min(64, max(16, triton.next_power_of_2(T)))
1344
+
1345
+ # 2-D indices giving, for every chunk, its (sequence index, chunk index) pair
1347
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
1348
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
1349
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] (a reference sketch follows at the end of this file)
1349
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
1350
+ g_cumsum, A, h, ht, o = chunk_gla_fwd(
1351
+ q=q,
1352
+ k=k,
1353
+ v=v,
1354
+ g=g,
1355
+ g_cumsum=None,
1356
+ scale=scale,
1357
+ initial_state=initial_state,
1358
+ output_final_state=output_final_state,
1359
+ offsets=offsets,
1360
+ indices=indices,
1361
+ head_first=head_first,
1362
+ chunk_size=chunk_size
1363
+ )
1364
+ # if g is not fp32, save memory by dropping g_cumsum and recomputing it in the backward pass; otherwise keep g_cumsum and drop g
1365
+ if g.dtype != torch.float:
1366
+ g_cumsum = None
1367
+ else:
1368
+ g = None
1369
+ ctx.save_for_backward(q, k, v, g, g_cumsum, initial_state, A)
1370
+ ctx.chunk_size = chunk_size
1371
+ ctx.scale = scale
1372
+ ctx.offsets = offsets
1373
+ ctx.indices = indices
1374
+ ctx.head_first = head_first
1375
+ return o, ht
1376
+
1377
+ @staticmethod
1378
+ @input_guard
1379
+ def backward(ctx, do, dht):
1380
+ q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors
1381
+ chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first
1382
+ dq, dk, dv, dg, dh0 = chunk_gla_bwd(
1383
+ q=q,
1384
+ k=k,
1385
+ v=v,
1386
+ g=g,
1387
+ g_cumsum=g_cumsum,
1388
+ scale=scale,
1389
+ h=None,
1390
+ A=A,
1391
+ initial_state=initial_state,
1392
+ do=do,
1393
+ dht=dht,
1394
+ offsets=offsets,
1395
+ indices=indices,
1396
+ head_first=head_first,
1397
+ chunk_size=chunk_size
1398
+ )
1399
+ return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None, None
1400
+
1401
+
1402
+ @torch.compiler.disable
1403
+ def chunk_gla(
1404
+ q: torch.Tensor,
1405
+ k: torch.Tensor,
1406
+ v: torch.Tensor,
1407
+ g: torch.Tensor,
1408
+ scale: Optional[float] = None,
1409
+ initial_state: torch.Tensor = None,
1410
+ output_final_state: bool = False,
1411
+ cu_seqlens: Optional[torch.LongTensor] = None,
1412
+ head_first: bool = True
1413
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1414
+ r"""
1415
+ Args:
1416
+ q (torch.Tensor):
1417
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
1418
+ k (torch.Tensor):
1419
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
1420
+ v (torch.Tensor):
1421
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1422
+ g (torch.Tensor):
1423
+ Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys.
1424
+ scale (Optional[float]):
1425
+ Scale factor for the attention scores.
1426
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1427
+ initial_state (Optional[torch.Tensor]):
1428
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
1429
+ For equal-length input sequences, `N` equals the batch size `B`.
1430
+ Default: `None`.
1431
+ output_final_state (Optional[bool]):
1432
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
1433
+ cu_seqlens (torch.LongTensor):
1434
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1435
+ consistent with the FlashAttention API.
1436
+ head_first (Optional[bool]):
1437
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
1438
+ Default: `True`.
1439
+
1440
+ Returns:
1441
+ o (torch.Tensor):
1442
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1443
+ final_state (torch.Tensor):
1444
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
1445
+
1446
+ Examples::
1447
+ >>> import torch
1448
+ >>> import torch.nn.functional as F
1449
+ >>> from einops import rearrange
1450
+ >>> from fla.ops.gla import chunk_gla
1451
+ # inputs with equal lengths
1452
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
1453
+ >>> q = torch.randn(B, T, H, K, device='cuda')
1454
+ >>> k = torch.randn(B, T, H, K, device='cuda')
1455
+ >>> v = torch.randn(B, T, H, V, device='cuda')
1456
+ >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
1457
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
1458
+ >>> o, ht = chunk_gla(q, k, v, g,
1459
+ initial_state=h0,
1460
+ output_final_state=True,
1461
+ head_first=False)
1462
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
1463
+ >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
1464
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
1465
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
1466
+ >>> o_var, ht_var = chunk_gla(q, k, v, g,
1467
+ initial_state=h0,
1468
+ output_final_state=True,
1469
+ cu_seqlens=cu_seqlens,
1470
+ head_first=False)
1471
+ >>> assert o.allclose(o_var.view(o.shape))
1472
+ >>> assert ht.allclose(ht_var)
1473
+ """
1474
+ if cu_seqlens is not None:
1475
+ if q.shape[0] != 1:
1476
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
1477
+ f"Please flatten variable-length inputs before processing.")
1478
+ if head_first:
1479
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
1480
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
1481
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
1482
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
1483
+ if scale is None:
1484
+ scale = q.shape[-1] ** -0.5
1485
+ o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens, head_first)
1486
+ return o, final_state
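To make the `indices` layout used throughout this file concrete (see the comment in `ChunkGLAFunction.forward`), the sketch below builds the same `[num_chunks, 2]` table of (sequence, chunk) pairs from `offsets` and `chunk_size`. The helper name `build_chunk_indices` is hypothetical; it only mirrors what `prepare_chunk_indices` is expected to return.

    import torch

    def build_chunk_indices(offsets: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
        # one (sequence index, chunk index) row per chunk of every variable-length sequence
        lengths = (offsets[1:] - offsets[:-1]).tolist()
        pairs = [(n, c)
                 for n, t in enumerate(lengths)
                 for c in range(-(-t // chunk_size))]   # ceil division
        return torch.tensor(pairs, dtype=torch.long)

    # offsets [0, 100, 356], chunk_size 64 -> [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
    print(build_chunk_indices(torch.tensor([0, 100, 356]), 64).tolist())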
fla/ops/gsa/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (69.4 kB). View file
 
fla/ops/hgrn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (288 Bytes). View file
 
fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.3 kB). View file
 
fla/ops/lightning_attn/chunk.py ADDED
@@ -0,0 +1,74 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.chunk import chunk_simple_gla
9
+
10
+
11
+ @torch.compiler.disable
12
+ def chunk_lightning_attn(
13
+ q: torch.Tensor,
14
+ k: torch.Tensor,
15
+ v: torch.Tensor,
16
+ layer_idx: int,
17
+ num_layers: int,
18
+ scale: Optional[float] = None,
19
+ initial_state: Optional[torch.Tensor] = None,
20
+ output_final_state: bool = False,
21
+ cu_seqlens: Optional[torch.LongTensor] = None,
22
+ head_first: bool = True
23
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ r"""
25
+ Args:
26
+ q (torch.Tensor):
27
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
28
+ k (torch.Tensor):
29
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
30
+ v (torch.Tensor):
31
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
32
+ layer_idx (int):
33
+ The index of the current layer.
34
+ num_layers (int):
35
+ The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor.
36
+ scale (Optional[float]):
37
+ Scale factor for the attention scores.
38
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
39
+ initial_state (Optional[torch.Tensor]):
40
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
41
+ For equal-length input sequences, `N` equals the batch size `B`.
42
+ Default: `None`.
43
+ output_final_state (Optional[bool]):
44
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
45
+ cu_seqlens (torch.LongTensor):
46
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
47
+ consistent with the FlashAttention API.
48
+ head_first (Optional[bool]):
49
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
50
+ Default: `True`.
51
+
52
+ Returns:
53
+ o (torch.Tensor):
54
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
55
+ final_state (torch.Tensor):
56
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
57
+ """
58
+ H = q.shape[1] if head_first else q.shape[2]
59
+ s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float)
60
+ if head_first:
61
+ g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
62
+ else:
63
+ g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
64
+ return chunk_simple_gla(
65
+ q=q,
66
+ k=k,
67
+ v=v,
68
+ scale=scale,
69
+ g=g,
70
+ initial_state=initial_state,
71
+ output_final_state=output_final_state,
72
+ head_first=head_first,
73
+ cu_seqlens=cu_seqlens
74
+ )
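For clarity, the only thing `chunk_lightning_attn` adds on top of `chunk_simple_gla` is a constant per-head log-decay slope, broadcast over batch and time. A minimal numeric sketch with assumed values (H=4, layer_idx=1, num_layers=12), not part of the file above:

    import torch

    H, layer_idx, num_layers = 4, 1, 12                    # assumed values, for illustration
    s = -(8 / H * (1 - layer_idx / num_layers)) * torch.arange(H, dtype=torch.float)
    print(s)  # tensor([-0.0000, -1.8333, -3.6667, -5.5000]); later heads decay faster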