YongganFu committed
Commit fb7c603 · verified · 1 Parent(s): 9afe10b

Upload FastSLMForCausalLM
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
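+ Until official usage instructions are filled in, here is a minimal sketch. It assumes the checkpoint lives in a Hub repo (the repo id below is a placeholder), that a tokenizer is uploaded alongside it, and that the custom `FastSLMConfig`/`FastSLMForCausalLM` classes from this commit are loaded via `trust_remote_code=True`:
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ repo_id = "<org>/<fastslm-repo>"  # placeholder: replace with the actual Hub repo id
+
+ # trust_remote_code pulls in configuration_fast_slm.py / modeling_fast_slm.py from the repo.
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     repo_id,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,  # config.json declares bfloat16 weights
+ ).cuda()                         # flash_attention_2 and the mamba kernels expect a CUDA device
+
+ inputs = tokenizer("Hello, FastSLM!", return_tensors="pt").to(model.device)
+ outputs = model.generate(**inputs, max_new_tokens=32)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```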
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,112 @@
1
+ {
2
+ "architectures": [
3
+ "FastSLMForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "attn_hidden_size": -1,
7
+ "attn_implementation": "flash_attention_2",
8
+ "attn_implementation_new": "flash_attention_2",
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_fast_slm.FastSLMConfig",
11
+ "AutoModelForCausalLM": "modeling_fast_slm.FastSLMForCausalLM"
12
+ },
13
+ "bos_token_id": 1,
14
+ "calc_logits_for_entire_prompt": false,
15
+ "d_conv": 4,
16
+ "dtype": "bfloat16",
17
+ "eos_token_id": 2,
18
+ "ffn_expand_ratio": 3,
19
+ "global_attn_idx": [],
20
+ "hidden_act": "silu",
21
+ "hidden_size": 2048,
22
+ "hybrid_decoder_layer": "mamba",
23
+ "initializer_range": 0.02,
24
+ "intermediate_size": 0,
25
+ "kq_head_dim": -1,
26
+ "kq_norm": "none",
27
+ "layer_type": [
28
+ "m",
29
+ "a",
30
+ "m",
31
+ "a",
32
+ "a",
33
+ "a",
34
+ "m",
35
+ "a",
36
+ "m",
37
+ "a",
38
+ "m",
39
+ "a",
40
+ "a",
41
+ "a",
42
+ "m",
43
+ "a",
44
+ "m",
45
+ "a",
46
+ "m",
47
+ "a",
48
+ "m",
49
+ "a",
50
+ "m",
51
+ "a"
52
+ ],
53
+ "layer_types": [
54
+ "deltanet",
55
+ "f",
56
+ "m2",
57
+ "f",
58
+ "a",
59
+ "f",
60
+ "m2",
61
+ "f",
62
+ "deltanet",
63
+ "f",
64
+ "m2",
65
+ "f",
66
+ "a",
67
+ "f",
68
+ "m2",
69
+ "f",
70
+ "deltanet",
71
+ "f",
72
+ "m2",
73
+ "f",
74
+ "deltanet",
75
+ "f",
76
+ "m2",
77
+ "f"
78
+ ],
79
+ "mamba2_headdim": 64,
80
+ "mamba_conv_bias": true,
81
+ "mamba_d_conv": 4,
82
+ "mamba_d_state": 128,
83
+ "mamba_dt_rank": 128,
84
+ "mamba_expand": 2,
85
+ "mamba_inner_layernorms": true,
86
+ "mamba_proj_bias": false,
87
+ "max_position_embeddings": 36000,
88
+ "mlp_hidden_act": "silu",
89
+ "model_type": "jamba",
90
+ "new_seq_length": 2048,
91
+ "num_attention_heads": 16,
92
+ "num_experts": 1,
93
+ "num_experts_per_tok": 1,
94
+ "num_hidden_layers": 24,
95
+ "num_key_value_heads": 4,
96
+ "num_memory_tokens": 256,
97
+ "orig_max_position_embeddings": 4096,
98
+ "output_router_logits": false,
99
+ "pad_token_id": 0,
100
+ "rms_norm_eps": 1e-06,
101
+ "rope": true,
102
+ "rope_theta": 10000.0,
103
+ "rope_type": "ntk",
104
+ "router_aux_loss_coef": 0.001,
105
+ "sliding_window": null,
106
+ "tie_word_embeddings": true,
107
+ "transformers_version": "4.56.2",
108
+ "use_cache": false,
109
+ "use_mamba_kernels": true,
110
+ "v_head_dim": -1,
111
+ "vocab_size": 131072
112
+ }
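The config above encodes a 24-layer hybrid schedule. A minimal sketch for inspecting it, assuming config.json has been downloaded locally; the short codes are inferred from the accompanying modules ("deltanet" for the DeltaNet layer, "m2" presumably Mamba2, "a" attention, "f" a feed-forward block, with "m"/"a" in `layer_type` giving a coarser mamba/attention split):

```python
import json
from collections import Counter

with open("config.json") as f:
    cfg = json.load(f)

# Fine-grained schedule: Counter({'f': 12, 'm2': 6, 'deltanet': 4, 'a': 2})
print(Counter(cfg["layer_types"]))

# Coarse mamba/attention view: Counter({'a': 14, 'm': 10})
print(Counter(cfg["layer_type"]))
```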
configuration_fast_slm.py ADDED
@@ -0,0 +1,246 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Jamba model configuration"""
16
+ import math
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class FastSLMConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
28
+ Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the jamba-small architecture.
30
+
31
+ [ai21labs/jamba-small](https://huggingface.co/ai21labs/Jamba-v0.1)
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 65536):
39
+ Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`JambaModel`]
41
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
42
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
43
+ model has a output word embedding layer.
44
+ hidden_size (`int`, *optional*, defaults to 4096):
45
+ Dimension of the hidden representations.
46
+ intermediate_size (`int`, *optional*, defaults to 14336):
47
+ Dimension of the MLP representations.
48
+ num_hidden_layers (`int`, *optional*, defaults to 32):
49
+ Number of hidden layers in the Transformer encoder.
50
+ num_attention_heads (`int`, *optional*, defaults to 32):
51
+ Number of attention heads for each attention layer in the Transformer encoder.
52
+ num_key_value_heads (`int`, *optional*, defaults to 8):
53
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
54
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
55
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
56
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
57
+ by meanpooling all the original heads within that group. For more details checkout [this
58
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ calc_logits_for_entire_prompt (`bool`, *optional*, defaults to `False`):
69
+ Whether or not to calculate logits for entire prompt during generation. If `False`, only the logits of the
70
+ last prompt token will be calculated, which are the only logits needed for generation. For long sequences,
71
+ the logits for the entire sequence may use a lot of memory so setting `calc_logits_for_entire_prompt=False`
72
+ will reduce memory footprint significantly.
73
+ Note: some generation features may not be available if this is set to `False`.
74
+ output_router_logits (`bool`, *optional*, defaults to `False`):
75
+ Whether or not the router logits should be returned by the model. Enabling this will also
76
+ allow the model to output the auxiliary loss. See [here]() for more details
77
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
78
+ The aux loss factor for the total loss.
79
+ pad_token_id (`int`, *optional*, defaults to 0):
80
+ The id of the padding token.
81
+ bos_token_id (`int`, *optional*, defaults to 1):
82
+ The id of the "beginning-of-sequence" token.
83
+ eos_token_id (`int`, *optional*, defaults to 2):
84
+ The id of the "end-of-sequence" token.
85
+ sliding_window (`int`, *optional*):
86
+ Sliding window attention window size. If not specified, will default to `None`.
87
+ n_ctx (`int`, *optional*, defaults to 262144):
88
+ This value doesn't have any real effect. The maximum sequence length that this model is intended to be
89
+ used with. It can be used with longer sequences, but performance may degrade.
90
+ attention_dropout (`float`, *optional*, defaults to 0.0):
91
+ The dropout ratio for the attention probabilities.
92
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
93
+ The number of experts to root per-token, can be also interpreted as the `top-p` routing
94
+ parameter
95
+ num_experts (`int`, *optional*, defaults to 16):
96
+ Number of experts per Sparse MLP layer.
97
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`):
98
+ Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
99
+ `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
100
+ `True` and kernels are not available
101
+ mamba_d_state (`int`, *optional*, defaults to 16):
102
+ The dimension the mamba state space latents
103
+ mamba_d_conv (`int`, *optional*, defaults to 4):
104
+ The size of the mamba convolution kernel
105
+ mamba_expand (`int`, *optional*, defaults to 2):
106
+ Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
107
+ mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
108
+ Rank of the the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
109
+ mamba_conv_bias (`bool`, *optional*, defaults to `True`):
110
+ Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
111
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`):
112
+ Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
113
+ mamba_inner_layernorms (`bool`, *optional*, defaults to `True`):
114
+ Flag indicating whether or not to apply layernorms to internal mamba activations
115
+
116
+ """
117
+
118
+ model_type = "jamba"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=65536,
124
+ tie_word_embeddings=False,
125
+ hidden_size=4096,
126
+ intermediate_size=14336,
127
+ num_hidden_layers=32,
128
+ num_attention_heads=32,
129
+ num_key_value_heads=8,
130
+ hidden_act="silu",
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-6,
133
+ use_cache=True,
134
+ calc_logits_for_entire_prompt=False,
135
+ output_router_logits=False,
136
+ router_aux_loss_coef=0.001,
137
+ pad_token_id=0,
138
+ bos_token_id=1,
139
+ eos_token_id=2,
140
+ sliding_window=None,
141
+ max_position_embeddings=262144,
142
+ orig_max_position_embeddings=None,
143
+ attention_dropout=0.0,
144
+ num_experts_per_tok=2,
145
+ num_experts=16,
146
+ use_mamba_kernels=True,
147
+ mamba_d_state=16,
148
+ mamba_d_conv=4,
149
+ mamba_expand=2,
150
+ mamba_dt_rank="auto",
151
+ mamba_conv_bias=True,
152
+ mamba_proj_bias=False,
153
+ mamba_inner_layernorms=True,
154
+
155
+ hybrid_decoder_layer='mamba',
156
+
157
+ global_attn_idx=None,
158
+
159
+ attn_implementation_new='flash_attention_2',
160
+
161
+ mamba2_headdim=64,
162
+
163
+ rope_type=None,
164
+
165
+ layer_types=None,
166
+
167
+ ffn_expand_ratio=None,
168
+
169
+ d_conv=4,
170
+
171
+ **kwargs,
172
+ ):
173
+ self.vocab_size = vocab_size
174
+ self.tie_word_embeddings = tie_word_embeddings
175
+ self.hidden_size = hidden_size
176
+ self.intermediate_size = intermediate_size
177
+ self.num_hidden_layers = num_hidden_layers
178
+ self.num_attention_heads = num_attention_heads
179
+ self.sliding_window = sliding_window
180
+ self.max_position_embeddings = max_position_embeddings
181
+ self.orig_max_position_embeddings = orig_max_position_embeddings
182
+ self.attention_dropout = attention_dropout
183
+
184
+ # for backward compatibility
185
+ if num_key_value_heads is None:
186
+ num_key_value_heads = num_attention_heads
187
+
188
+ self.num_key_value_heads = num_key_value_heads
189
+ self.hidden_act = hidden_act
190
+ self.initializer_range = initializer_range
191
+ self.rms_norm_eps = rms_norm_eps
192
+
193
+ self.use_cache = use_cache
194
+ self.calc_logits_for_entire_prompt = calc_logits_for_entire_prompt
195
+ self.output_router_logits = output_router_logits
196
+ self.router_aux_loss_coef = router_aux_loss_coef
197
+
198
+ self.num_experts_per_tok = num_experts_per_tok
199
+ self.num_experts = num_experts
200
+
201
+ self.use_mamba_kernels = use_mamba_kernels
202
+ self.mamba_d_state = mamba_d_state
203
+ self.mamba_d_conv = mamba_d_conv
204
+ self.mamba_expand = mamba_expand
205
+ self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
206
+ self.mamba_conv_bias = mamba_conv_bias
207
+ self.mamba_proj_bias = mamba_proj_bias
208
+ self.mamba_inner_layernorms = mamba_inner_layernorms
209
+
210
+ # added by Xin
211
+ self.kq_norm = kwargs.pop("kq_norm", None)
212
+ self.rope = kwargs.pop("rope", False)
213
+ self.rope_theta = kwargs.pop("rope_theta", 10000.0)
214
+ self.num_memory_tokens = kwargs.pop("num_memory_tokens", 0)
215
+ self.attn_hidden_size = kwargs.pop("attn_hidden_size", -1)
216
+ self.kq_head_dim = kwargs.pop("kq_head_dim", -1)
217
+ self.v_head_dim = kwargs.pop("v_head_dim", -1)
218
+
219
+ #! adhoc change
220
+ self.new_seq_length = 2048
221
+
222
+ self.hybrid_decoder_layer = hybrid_decoder_layer
223
+
224
+ self.global_attn_idx = global_attn_idx
225
+
226
+ self.attn_implementation_new = attn_implementation_new
227
+
228
+ self.mamba2_headdim = mamba2_headdim
229
+
230
+ self.rope_type = rope_type
231
+
232
+ self.layer_types = layer_types
233
+
234
+ self.ffn_expand_ratio = ffn_expand_ratio
235
+
236
+ self.d_conv = d_conv
237
+
238
+ self.mlp_hidden_act = kwargs.pop("mlp_hidden_act", "silu")
239
+
240
+ super().__init__(
241
+ pad_token_id=pad_token_id,
242
+ bos_token_id=bos_token_id,
243
+ eos_token_id=eos_token_id,
244
+ tie_word_embeddings=tie_word_embeddings,
245
+ **kwargs,
246
+ )
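As a quick check of the `mamba_dt_rank="auto"` behaviour documented above, a minimal sketch (assuming `configuration_fast_slm.py` is importable from the working directory; the values mirror config.json from this commit):

```python
import math
from configuration_fast_slm import FastSLMConfig

cfg = FastSLMConfig(
    hidden_size=2048,
    num_hidden_layers=24,
    num_attention_heads=16,
    num_key_value_heads=4,
    vocab_size=131072,
    mamba_d_state=128,
    mamba_dt_rank="auto",
)

# "auto" resolves to math.ceil(hidden_size / 16) = 128, matching "mamba_dt_rank": 128 in config.json.
assert cfg.mamba_dt_rank == math.ceil(2048 / 16) == 128
print(cfg.model_type)  # "jamba"
```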
delta_net.py ADDED
@@ -0,0 +1,472 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
+
16
+ from typing import Any, Dict, List, Optional, Tuple
17
+
18
+ import torch
19
+ import transformers
20
+
21
+ if TYPE_CHECKING:
22
+ from transformers.processing_utils import Unpack
23
+
24
+ from fla.models.utils import Cache
25
+
26
+
27
+ def elu_p1(x):
28
+ return (F.elu(x, 1., False) + 1.).to(x)
29
+
30
+
31
+ def sum_norm(x):
32
+ return (x / x.sum(-1, keepdim=True)).to(x)
33
+
34
+
35
+ class DeltaNet(nn.Module):
36
+ r"""
37
+ The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa:
38
+ DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa
39
+
40
+ Args:
41
+ mode (str, Optional):
42
+ Which DeltaNet kernel to use.
43
+ Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
44
+ Default: `chunk`.
45
+ hidden_size (int, Optional):
46
+ The hidden size of the input. Default: 1024.
47
+ expand_k (float, Optional):
48
+ The expansion ratio for the key dim. Default: 1.0.
49
+ expand_v (float, Optional):
50
+ The expansion ratio for the value dim. Default: 1.0.
51
+ num_heads (int, Optional):
52
+ The number of heads. Default: 4.
53
+ use_beta (bool, Optional):
54
+ Whether to use beta. Default: `True`.
55
+ use_gate (bool, Optional):
56
+ Whether to use output gate. Default: `False`.
57
+ use_short_conv (bool, Optional):
58
+ Whether to use short convolutions. Default: `True`.
59
+ conv_size (int, Optional):
60
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
61
+ conv_bias (bool, Optional):
62
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
63
+ allow_neg_eigval (bool, Optional):
64
+ Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
65
+ See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
66
+ layer_idx (int, Optional):
67
+ The index of the layer. Default: None.
68
+ norm_eps (float, Optional):
69
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
70
+ qk_activation (str, Optional):
71
+ The activation function for the query and key. Default: `silu`.
72
+ qk_norm (str, Optional):
73
+ The normalization method for the query and key. Default: `l2`.
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ mode: str = 'chunk',
79
+ d_model: int = None,
80
+ hidden_size: int = 1024,
81
+ expand_k: float = 1.0,
82
+ expand_v: float = 1.0,
83
+ num_heads: int = 4,
84
+ use_beta: bool = True,
85
+ use_gate: bool = False,
86
+ use_short_conv: bool = True,
87
+ conv_size: int = 4,
88
+ conv_bias: bool = False,
89
+ allow_neg_eigval: bool = False,
90
+ layer_idx: int = None,
91
+ qk_activation: str = 'silu',
92
+ qk_norm: str = 'l2',
93
+ norm_eps: float = 1e-5,
94
+ config = None,
95
+ **kwargs
96
+ ) -> DeltaNet:
97
+ super().__init__()
98
+
99
+ self.mode = mode
100
+ self.qk_activation = qk_activation
101
+ self.qk_norm = qk_norm
102
+
103
+ assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
104
+ assert self.qk_norm in ['l2', 'sum']
105
+
106
+ if d_model is not None:
107
+ hidden_size = d_model
108
+ self.hidden_size = hidden_size
109
+ self.expand_k = expand_k
110
+ self.expand_v = expand_v
111
+ self.num_heads = num_heads
112
+ self.use_gate = use_gate
113
+ self.use_short_conv = use_short_conv
114
+ self.conv_size = conv_size
115
+ self.conv_bias = conv_bias
116
+ self.allow_neg_eigval = allow_neg_eigval
117
+
118
+ self.key_dim = int(hidden_size * expand_k)
119
+ self.value_dim = int(hidden_size * expand_v)
120
+ self.head_k_dim = self.key_dim // num_heads
121
+ self.head_v_dim = self.value_dim // num_heads
122
+ self.layer_idx = layer_idx
123
+
124
+ self.silu = nn.SiLU()
125
+ if mode == 'fused_chunk':
126
+ raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
127
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
128
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
129
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
130
+
131
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
132
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
133
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
134
+
135
+ self.use_beta = use_beta
136
+ if self.use_beta:
137
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
138
+ if use_short_conv:
139
+ self.conv_size = conv_size
140
+ self.q_conv1d = ShortConvolution(
141
+ hidden_size=self.key_dim,
142
+ kernel_size=conv_size,
143
+ activation='silu' if qk_activation == 'silu' else None
144
+ )
145
+ self.k_conv1d = ShortConvolution(
146
+ hidden_size=self.key_dim,
147
+ kernel_size=conv_size,
148
+ activation='silu' if qk_activation == 'silu' else None
149
+ )
150
+ self.v_conv1d = ShortConvolution(
151
+ hidden_size=self.value_dim,
152
+ kernel_size=conv_size,
153
+ activation='silu'
154
+ )
155
+ else:
156
+ raise UserWarning(
157
+ "ShortConvolution is crucial to the performance. "
158
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
159
+ )
160
+ if use_gate:
161
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
162
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
163
+ else:
164
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
165
+
166
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
167
+
168
+ self.apply(self._initialize_weights)
169
+
170
+ def _initialize_weights(self, module: nn.Module):
171
+ if getattr(module, "_is_hf_initialized", False):
172
+ return
173
+ if isinstance(module, nn.Linear):
174
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
175
+ if module.bias is not None:
176
+ nn.init.zeros_(module.bias)
177
+ module._is_hf_initialized = True
178
+
179
+ def forward(
180
+ self,
181
+ hidden_states: torch.Tensor,
182
+ attention_mask: Optional[torch.Tensor] = None,
183
+ past_key_values: Optional[Cache] = None,
184
+ use_cache: Optional[bool] = False,
185
+ output_attentions: Optional[bool] = False,
186
+ **kwargs: Unpack[Dict]
187
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
188
+ if attention_mask is not None:
189
+ assert len(attention_mask.shape) == 2, (
190
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
191
+ "for padding purposes (0 indicating padding). "
192
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
193
+ )
194
+
195
+ # change to inference mode.
196
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
197
+
198
+ last_state = None
199
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
200
+ last_state = past_key_values[self.layer_idx]
201
+
202
+ if self.use_short_conv:
203
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
204
+ if last_state is not None:
205
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
206
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
207
+ position_ids = kwargs.get('position_ids', None)
208
+
209
+ q = self.q_proj(hidden_states)
210
+
211
+ q, conv_state_q = self.q_conv1d(x=q,
212
+ mask=conv_mask,
213
+ cache=conv_state_q,
214
+ output_final_state=use_cache,
215
+ seq_idx=position_ids)
216
+
217
+ k = self.k_proj(hidden_states)
218
+
219
+ k, conv_state_k = self.k_conv1d(x=k,
220
+ mask=conv_mask,
221
+ cache=conv_state_k,
222
+ output_final_state=use_cache,
223
+ seq_idx=position_ids)
224
+
225
+ v = self.v_proj(hidden_states)
226
+
227
+ v, conv_state_v = self.v_conv1d(x=v,
228
+ mask=conv_mask,
229
+ cache=conv_state_v,
230
+ output_final_state=use_cache,
231
+ seq_idx=position_ids)
232
+ else:
233
+ q = self.q_proj(hidden_states)
234
+ k = self.k_proj(hidden_states)
235
+ v = self.v_proj(hidden_states)
236
+
237
+ if self.qk_activation == 'silu':
238
+ q, k = self.silu(q), self.silu(k)
239
+
240
+ v = self.silu(v)
241
+
242
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
243
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
244
+ if self.qk_activation != 'silu':
245
+ if self.qk_activation == 'relu':
246
+ q, k = q.relu(), k.relu()
247
+ elif self.qk_activation == 'elu':
248
+ q, k = elu_p1(q), elu_p1(k)
249
+ elif self.qk_activation == 'identity':
250
+ pass
251
+ else:
252
+ raise NotImplementedError
253
+
254
+ if self.qk_norm == 'sum':
255
+ q = sum_norm(q).to(q)
256
+ k = sum_norm(k).to(k)
257
+
258
+ if self.use_beta:
259
+ beta = self.b_proj(hidden_states)
260
+ beta = beta.sigmoid()
261
+ else:
262
+ beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
263
+
264
+ if self.allow_neg_eigval:
265
+ beta = beta * 2.
266
+
267
+ # dealing with padding
268
+ if attention_mask is not None:
269
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
270
+
271
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
272
+
273
+ cu_seqlens = kwargs.get('cu_seqlens', None)
274
+ if mode == 'fused_recurrent':
275
+ o, recurrent_state = fused_recurrent_delta_rule(
276
+ q=q,
277
+ k=k,
278
+ v=v,
279
+ beta=beta,
280
+ initial_state=recurrent_state,
281
+ output_final_state=use_cache,
282
+ cu_seqlens=cu_seqlens,
283
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
284
+ )
285
+ elif mode == 'chunk':
286
+ o, recurrent_state = chunk_delta_rule(
287
+ q=q,
288
+ k=k,
289
+ v=v,
290
+ beta=beta,
291
+ initial_state=recurrent_state,
292
+ output_final_state=use_cache,
293
+ cu_seqlens=cu_seqlens,
294
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
295
+ )
296
+ else:
297
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
298
+
299
+ if past_key_values is not None:
300
+ past_key_values.update(
301
+ recurrent_state=recurrent_state,
302
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
303
+ layer_idx=self.layer_idx,
304
+ offset=q.shape[1]
305
+ )
306
+
307
+ if self.use_gate:
308
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
309
+ o = self.o_norm(o, g)
310
+ else:
311
+ o = self.o_norm(o)
312
+ o = rearrange(o, 'b t h d -> b t (h d)')
313
+ o = self.o_proj(o)
314
+
315
+ return o, None, past_key_values
316
+
317
+
318
+ class Cache(transformers.cache_utils.Cache):
319
+ """
320
+ A cache used for storing hidden states produced by flash linear attention models.
321
+
322
+ It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`.
323
+ """
324
+
325
+ is_compileable = True
326
+
327
+ def __init__(
328
+ self,
329
+ seen_tokens: int = 0
330
+ ) -> Cache:
331
+ super().__init__(layers=[0])
332
+
333
+ self.states: List[Dict[str, Any]] = []
334
+
335
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
336
+
337
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
338
+ if layer_idx < len(self):
339
+ return self.states[layer_idx]
340
+ else:
341
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
342
+
343
+ def __iter__(self):
344
+ for state in self.states:
345
+ yield state
346
+
347
+ def __len__(self):
348
+ return len(self.states)
349
+
350
+ def reset(self):
351
+ for state in self.states:
352
+ for key in state:
353
+ if state[key] is not None:
354
+ if type(state[key]) == tuple:
355
+ for subkey in state[key]:
356
+ subkey.zero_()
357
+ else:
358
+ state[key].zero_()
359
+ self._seen_tokens = 0
360
+
361
+
362
+ def update(
363
+ self,
364
+ recurrent_state: Optional[Tuple[torch.Tensor]] = None,
365
+ attn_state: Optional[Tuple[torch.Tensor]] = None,
366
+ conv_state: Optional[Tuple[torch.Tensor]] = None,
367
+ ffn_state: Optional[Tuple[torch.Tensor]] = None,
368
+ layer_idx: int = 0,
369
+ offset: Optional[int] = 1,
370
+ cache_kwargs: Optional[Dict[str, Any]] = None,
371
+ ) -> Dict[str, Any]:
372
+ """
373
+ Args:
374
+ recurrent_state (`torch.Tensor`):
375
+ The new recurrent state to cache.
376
+ attn_state (`Tuple[torch.Tensor]`):
377
+ The new attention key/value states to cache.
378
+ conv_state (`Tuple[torch.Tensor]`):
379
+ The new convolution state to cache.
380
+ ffn_state (`Tuple[torch.Tensor]`):
381
+ The new feed-forward state to cache.
382
+ layer_idx (`int`, defaults to 0):
383
+ The index of the layer to cache the states for.
384
+ offset (`int`, defaults to 1):
385
+ The number of new tokens being processed.
386
+ cache_kwargs (`Dict[str, Any]`):
387
+ Additional arguments for the cache subclass.
388
+
389
+ Return:
390
+ Dictionary of the updated state.
391
+ """
392
+
393
+ if cache_kwargs is None:
394
+ cache_kwargs = {}
395
+ if attn_state is not None:
396
+ input_size = attn_state[0].shape[1]
397
+ window_size = cache_kwargs.get('window_size', None)
398
+ if not (isinstance(attn_state, Tuple) or isinstance(attn_state, List)):
399
+ raise ValueError("`attn_state` must be a tuple of tensors for key/value states")
400
+ if len(self.states) <= layer_idx:
401
+ # update the number of seen tokens
402
+ if layer_idx == 0:
403
+ self._seen_tokens += offset
404
+ if attn_state is not None:
405
+ if window_size is not None and input_size > window_size:
406
+ attn_state = [state[:, -window_size:].contiguous() for state in attn_state]
407
+ state = dict(
408
+ recurrent_state=recurrent_state,
409
+ attn_state=attn_state,
410
+ conv_state=conv_state,
411
+ ffn_state=ffn_state
412
+ )
413
+ self.states.append(state)
414
+ else:
415
+ # update the number of seen tokens
416
+ if layer_idx == len(self.states) - 1:
417
+ self._seen_tokens += offset
418
+ state = self.states[layer_idx]
419
+ if recurrent_state is not None:
420
+ state['recurrent_state'].copy_(recurrent_state)
421
+ if attn_state is not None:
422
+ if window_size is not None and state['attn_state'][0].shape[1] == window_size:
423
+ for i, (old_state, new_state) in enumerate(zip(state['attn_state'], attn_state)):
424
+ # DO NOT allocate new memory if the cache is full
425
+ # roll the key/value states to the left by `input_size`
426
+ old_state = old_state.roll(-input_size, 1)
427
+ # replace the last `input_size` tokens with the new key/value states
428
+ old_state[:, -input_size:] = new_state
429
+ state['attn_state'][i].copy_(old_state)
430
+ else:
431
+ attn_state = [
432
+ torch.cat([old_state, new_state], 1)
433
+ for old_state, new_state in zip(state['attn_state'], attn_state)
434
+ ]
435
+ state['attn_state'].copy_(attn_state)
436
+ if conv_state is not None:
437
+ conv_state_q, conv_state_k, conv_state_v = state['conv_state']
438
+ conv_state_q.copy_(conv_state[0])
439
+ conv_state_k.copy_(conv_state[1])
440
+ conv_state_v.copy_(conv_state[2])
441
+ if ffn_state is not None:
442
+ state['ffn_state'].copy_(ffn_state)
443
+
444
+ return state
445
+
446
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
447
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
448
+ if len(self.states) <= layer_idx:
449
+ return 0
450
+ return self._seen_tokens
451
+
452
+ def get_max_length(self) -> Optional[int]:
453
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
454
+ return None
455
+
456
+ def to_legacy_cache(self) -> Tuple:
457
+ return tuple(self.states)
458
+
459
+ @classmethod
460
+ @torch.compiler.disable
461
+ def from_legacy_cache(
462
+ cls,
463
+ past_key_values: Optional[Tuple] = None,
464
+ seen_tokens: int = 0
465
+ ) -> Cache:
466
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
467
+
468
+ cache = cls(seen_tokens)
469
+ if isinstance(past_key_values, list):
470
+ for layer_idx in range(len(past_key_values)):
471
+ cache.states.append(past_key_values[layer_idx])
472
+ return cache
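The fused kernels used above (`chunk_delta_rule`, `fused_recurrent_delta_rule`) come from the `fla` package. As a sanity reference for what the DeltaNet layer computes, here is a naive per-token version of the delta-rule recurrence S_t = S_{t-1} + beta_t (v_t - S_{t-1} k_t) k_t^T with output o_t = S_t q_t; this is a plain-PyTorch sketch for illustration, not the production code path:

```python
import torch
import torch.nn.functional as F

def delta_rule_reference(q, k, v, beta):
    """q, k: [B, T, H, Dk]; v: [B, T, H, Dv]; beta: [B, T, H]."""
    B, T, H, Dk = q.shape
    Dv = v.shape[-1]
    # L2-normalize q/k, mirroring the layer's default qk_norm='l2'.
    q, k = F.normalize(q, p=2, dim=-1), F.normalize(k, p=2, dim=-1)
    S = q.new_zeros(B, H, Dv, Dk)  # per-head fast-weight state
    outs = []
    for t in range(T):
        q_t, k_t, v_t, b_t = q[:, t], k[:, t], v[:, t], beta[:, t]
        pred = torch.einsum('bhvk,bhk->bhv', S, k_t)                    # S_{t-1} k_t
        S = S + torch.einsum('bh,bhv,bhk->bhvk', b_t, v_t - pred, k_t)  # delta-rule update
        outs.append(torch.einsum('bhvk,bhk->bhv', S, q_t))              # o_t = S_t q_t
    return torch.stack(outs, dim=1)  # [B, T, H, Dv]

o = delta_rule_reference(
    torch.randn(2, 8, 4, 16), torch.randn(2, 8, 4, 16),
    torch.randn(2, 8, 4, 16), torch.rand(2, 8, 4),
)
print(o.shape)  # torch.Size([2, 8, 4, 16])
```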
fused_mha_with_cache.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ from typing import Optional, Tuple
3
+
4
+ from .triton_attention import (
5
+ fused_mha_with_paged_cache, fused_mha_with_cache
6
+ )
7
+
8
+ dtype_int = torch.int32
9
+
10
+ def fused_mha_interface(
11
+ query_states: torch.Tensor, # [batch, q_len, heads, head_dim]
12
+ key_states: torch.Tensor, # [batch, kv_len, heads, head_dim]
13
+ value_states: torch.Tensor, # [batch, kv_len, heads, head_dim]
14
+ k_cache: torch.Tensor, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD] or [num_pages, page_size, n, d] for paged attn
15
+ v_cache: torch.Tensor, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
16
+ position_ids: torch.Tensor=None,
17
+ page_table: torch.Tensor=None, # [b, max_num_pages_per_seq] # loc of the block page in the cache.
18
+ max_seq_len = None,
19
+ ) -> torch.Tensor:
20
+ """
21
+ Replacement for _flash_attention_forward(...) that uses
22
+ Triton’s fused_mha_with_paged_cache under the hood.
23
+ Returns: [batch, q_len, heads, head_dim]
24
+ """
25
+ # unpack shapes
26
+ b, ql, n_heads, head_dim = query_states.shape
27
+ _, kvl, n_kv_heads, _ = key_states.shape
28
+
29
+ q = query_states.reshape(b, ql, n_heads * head_dim)
30
+ k = key_states.reshape(b, kvl, n_kv_heads * head_dim)
31
+ v = value_states.reshape(b, kvl, n_kv_heads * head_dim)
32
+
33
+ if position_ids is not None:
34
+ if ql == 1: # Generate phase - single token
35
+ input_pos = position_ids[:, -1] # Use the last position for each sequence
36
+ else: # Context phase - multiple tokens
37
+ input_pos = position_ids[:, 0] # Use the starting position for each sequence
38
+ else:
39
+ # Fallback: assume starting from 0 for all sequences
40
+ input_pos = torch.zeros(b, device=q.device, dtype=torch.int32)
41
+
42
+ freqs_cis = None
43
+
44
+ if page_table is None:
45
+ y = torch.ops.attention.fused_mha_with_cache(
46
+ q, k, v,
47
+ input_pos,
48
+ k_cache, v_cache,
49
+ freqs_cis,
50
+ )
51
+
52
+
53
+ else:
54
+ batch_size = b
55
+
56
+ # cache_loc: identity mapping [0, 1, ..., b-1]
57
+ cache_loc = torch.arange(batch_size, device=q.device, dtype=dtype_int)
58
+
59
+ # input_positions: assume pure context (all start from 0)
60
+ input_positions = torch.zeros(batch_size, device=q.device, dtype=dtype_int)
61
+
62
+ # seq_len: each sequence length is kvl
63
+ seq_len = torch.full((batch_size,), kvl, device=q.device, dtype=dtype_int)
64
+
65
+ # seq_start: flattened starting index for each sequence
66
+ seq_start = (seq_len.cumsum(0) - seq_len).to(dtype=dtype_int)
67
+
68
+ assert max_seq_len is not None, "max_seq_len must be provided when using paged attention."
69
+
70
+ y = torch.ops.attention.fused_mha_with_paged_cache(
71
+ q, k, v,
72
+ input_positions, cache_loc,
73
+ seq_len, seq_start,
74
+ page_table, max_seq_len,
75
+ k_cache, v_cache,
76
+ freqs_cis,
77
+ )
78
+
79
+ y = y.view(b, ql, n_heads, head_dim)
80
+
81
+ return y
82
+
83
+
84
+
85
+ def main():
86
+ #––– Test hyperparameters –––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
87
+ batch_size = 1
88
+ q_len = 1
89
+ kv_len = 1
90
+ num_heads = 16
91
+ n_kv_heads = 16
92
+ head_dim = 128
93
+
94
+ max_batch_size = 1
95
+ max_seq_len = 1024
96
+
97
+ page_size = 256
98
+
99
+ device = "cuda"
100
+
101
+ #––– Random query, key, value tensors –––––––––––––––––––––––––––––––––––––––––––––––––––
102
+ query_states = torch.randn(batch_size, q_len, num_heads, head_dim, device=device)
103
+ key_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
104
+ value_states = torch.randn(batch_size, kv_len, num_heads, head_dim, device=device)
105
+
106
+ k_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)
107
+ v_cache = torch.randn(max_batch_size, max_seq_len, num_heads, head_dim, device=device)
108
+
109
+ attn_out = fused_mha_interface(
110
+ query_states,
111
+ key_states,
112
+ value_states,
113
+ k_cache=k_cache,
114
+ v_cache=v_cache,
115
+ )
116
+
117
+ expected_shape = (batch_size, q_len, num_heads, head_dim)
118
+ print(f"[test] output shape: {attn_out.shape} (expected {expected_shape})")
119
+
120
+ if attn_out.shape == expected_shape:
121
+ print("[test] ✅ Success: output tensor has correct shape.")
122
+ else:
123
+ print("[test] ❌ Failure: shape mismatch.")
124
+
125
+ if __name__ == "__main__":
126
+ main()
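For the paged path, the only layout information given is in the shape comments of `fused_mha_interface` (`page_table`: `[b, max_num_pages_per_seq]`, caches: `[num_pages, page_size, n, d]`). Below is a hedged sketch of one way those inputs could be laid out; the actual Triton op `torch.ops.attention.fused_mha_with_paged_cache` may impose further constraints:

```python
import torch

batch_size, max_seq_len, page_size = 2, 1024, 256
num_heads, head_dim = 16, 128

pages_per_seq = max_seq_len // page_size        # 4 pages of 256 tokens per sequence
num_pages = batch_size * pages_per_seq

# Paged KV caches: [num_pages, page_size, n_heads, head_dim]
k_cache = torch.zeros(num_pages, page_size, num_heads, head_dim, dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)

# Identity block assignment: sequence b owns pages [b * pages_per_seq, (b + 1) * pages_per_seq).
page_table = torch.arange(num_pages, dtype=torch.int32).view(batch_size, pages_per_seq)
print(page_table)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]], dtype=torch.int32)
```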
mamba2.py ADDED
@@ -0,0 +1,464 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat, pack, unpack
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
22
+ from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
23
+
24
+
25
+ from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
26
+ from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter
27
+
28
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
29
+ from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
30
+
31
+
32
+ class Mamba2(nn.Module):
33
+ def __init__(
34
+ self,
35
+ config,
36
+ conv_init=None,
37
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
38
+ ngroups=1,
39
+ A_init_range=(1, 16),
40
+ D_has_hdim=False,
41
+ rmsnorm=True,
42
+ norm_before_gate=False,
43
+ dt_min=0.001,
44
+ dt_max=0.1,
45
+ dt_init_floor=1e-4,
46
+ dt_limit=(0.0, float("inf")),
47
+ bias=False,
48
+ conv_bias=True,
49
+ # Fused kernel and sharding options
50
+ chunk_size=256,
51
+ use_mem_eff_path=False, # True,
52
+ layer_idx=None, # Absorb kwarg for general module
53
+ process_group=None,
54
+ sequence_parallel=True,
55
+ device=None,
56
+ dtype=None,
57
+ ):
58
+ factory_kwargs = {"device": device, "dtype": dtype}
59
+ super().__init__()
60
+
61
+ self.config = config
62
+ self.d_model = config.hidden_size
63
+ self.d_state = config.mamba_d_state
64
+ self.d_conv = config.mamba_d_conv
65
+
66
+ self.conv_init = conv_init
67
+ self.expand = config.mamba_expand
68
+ self.process_group = process_group
69
+ self.sequence_parallel = sequence_parallel
70
+ self.world_size = 1 if process_group is None else process_group.size()
71
+ self.local_rank = 0 if process_group is None else process_group.rank()
72
+ self.d_inner = (self.expand * self.d_model) // self.world_size
73
+ assert self.d_inner * self.world_size == self.expand * self.d_model
74
+ self.headdim = config.mamba2_headdim
75
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
76
+ assert ngroups % self.world_size == 0
77
+ self.ngroups = ngroups // self.world_size
78
+ assert self.d_ssm % self.headdim == 0
79
+ self.nheads = self.d_ssm // self.headdim
80
+ self.D_has_hdim = D_has_hdim
81
+ self.rmsnorm = rmsnorm
82
+ self.norm_before_gate = norm_before_gate
83
+ self.dt_limit = dt_limit
84
+ self.activation = "silu"
85
+ self.chunk_size = chunk_size
86
+ self.use_mem_eff_path = use_mem_eff_path
87
+ self.layer_idx = layer_idx
88
+
89
+ assert (self.d_model * self.expand / self.headdim) % 8 == 0
90
+
91
+ # Order: [z, x, B, C, dt]
92
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
93
+ if self.process_group is None:
94
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
95
+ else:
96
+ self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
97
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
98
+ **factory_kwargs)
99
+
100
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
101
+ self.conv1d = nn.Conv1d(
102
+ in_channels=conv_dim,
103
+ out_channels=conv_dim,
104
+ bias=conv_bias,
105
+ kernel_size=self.d_conv,
106
+ groups=conv_dim,
107
+ padding=self.d_conv - 1,
108
+ **factory_kwargs,
109
+ )
110
+ if self.conv_init is not None:
111
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
112
+
113
+ self.act = nn.SiLU()
114
+
115
+ # Initialize log dt bias
116
+ dt = torch.exp(
117
+ torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
118
+ + math.log(dt_min)
119
+ )
120
+ dt = torch.clamp(dt, min=dt_init_floor)
121
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
122
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
123
+
124
+ self.dt_bias = nn.Parameter(inv_dt)
125
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
126
+ # name.endswith("bias") in param_grouping.py
127
+ self.dt_bias._no_weight_decay = True
128
+
129
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
130
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
131
+ A_log = torch.log(A).to(dtype=dtype)
132
+ self.A_log = nn.Parameter(A_log)
133
+ self.A_log._no_weight_decay = True
134
+
135
+ # D "skip" parameter
136
+ self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
137
+ self.D._no_weight_decay = True
138
+
139
+ if self.rmsnorm:
140
+ assert RMSNormGated is not None
141
+ self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
142
+ group_size=self.d_ssm // ngroups, **factory_kwargs)
143
+
144
+ if self.process_group is None:
145
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
146
+ else:
147
+ self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
148
+ process_group=self.process_group, sequence_parallel=self.sequence_parallel,
149
+ **factory_kwargs)
150
+
151
+
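+ # For reference, with the values shipped in config.json (hidden_size=2048, mamba_expand=2,
+ # mamba2_headdim=64, mamba_d_state=128) and the defaults ngroups=1, world_size=1:
+ #   d_inner = 2 * 2048 = 4096, nheads = 4096 / 64 = 64
+ #   d_in_proj (order [z, x, B, C, dt]) = 2 * 4096 + 2 * 1 * 128 + 64 = 8512
+ #   conv_dim  (x, B, C)                = 4096 + 2 * 1 * 128 = 4352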
152
+ def forward(self, hidden_states, attention_mask=None, past_key_value=None, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
153
+ """
154
+ hidden_states: (batch, seqlen, hidden_dim) if seqlen=None.
155
+ If seqlen is not None, hidden_states is (batch * seqlen, hidden_dim). This is so that when we
156
+ split hidden_states during sequence parallel, we split the batch * seqlen dimension
157
+ (in case batch is small).
158
+ Returns: same shape as u
159
+ """
160
+ # assert past_key_value is None, "Not implemented yet!!!"
161
+
162
+ seqlen_og = seqlen
163
+ if seqlen is None:
164
+ batch, seqlen, dim = hidden_states.shape
165
+ else:
166
+ batch_seqlen, dim = hidden_states.shape
167
+ batch = batch_seqlen // seqlen
168
+
169
+ conv_state, ssm_state = None, None
170
+
171
+ if inference_params is not None:
172
+ inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
173
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
174
+
175
+ if inference_params.seqlen_offset > 0:
176
+ # The states are updated inplace
177
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
178
+ return out, past_key_value
179
+
180
+ zxbcdt = self.in_proj(hidden_states) # (B, L, d_in_proj) or (B * L, d_in_proj)
181
+
182
+ if seqlen_og is not None:
183
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
184
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
185
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
186
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
187
+ if self.use_mem_eff_path and inference_params is None:
188
+ out = mamba_split_conv1d_scan_combined(
189
+ zxbcdt,
190
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
191
+ self.conv1d.bias,
192
+ self.dt_bias,
193
+ A,
194
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
195
+ chunk_size=self.chunk_size,
196
+ seq_idx=seq_idx,
197
+ activation=self.activation,
198
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
199
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
200
+ outproj_weight=self.out_proj.weight,
201
+ outproj_bias=self.out_proj.bias,
202
+ headdim=None if self.D_has_hdim else self.headdim,
203
+ ngroups=self.ngroups,
204
+ norm_before_gate=self.norm_before_gate,
205
+ **dt_limit_kwargs,
206
+ )
207
+ if seqlen_og is not None:
208
+ out = rearrange(out, "b l d -> (b l) d")
209
+ if self.process_group is not None:
210
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
211
+ out = reduce_fn(out, self.process_group)
212
+ else:
213
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
214
+ z0, x0, z, xBC, dt = torch.split(
215
+ zxbcdt,
216
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
217
+ dim=-1
218
+ )
219
+
220
+ if conv_state is not None:
221
+ if cu_seqlens is None:
222
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
223
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
224
+ xBC_t = rearrange(xBC, "b l d -> b d l")
225
+ conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
226
+ else:
227
+ assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
228
+ assert batch == 1, "varlen inference only supports batch dimension 1"
229
+ conv_varlen_states = causal_conv1d_varlen_states(
230
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
231
+ )
232
+ conv_state.copy_(conv_varlen_states)
233
+ assert self.activation in ["silu", "swish"]
234
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
235
+ assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
236
+ xBC = self.act(
237
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)]
238
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
239
+ else:
240
+ xBC = causal_conv1d_fn(
241
+ xBC.transpose(1, 2),
242
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
243
+ bias=self.conv1d.bias,
244
+ activation=self.activation,
245
+ # seq_idx=seq_idx,
246
+ ).transpose(1, 2)
247
+
248
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
249
+
250
+
251
+ y = mamba_chunk_scan_combined(
252
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
253
+ dt,
254
+ A,
255
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
256
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
257
+ chunk_size=self.chunk_size,
258
+ # D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
259
+ D=self.D,
260
+ z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
261
+ dt_bias=self.dt_bias,
262
+ dt_softplus=True,
263
+ seq_idx=seq_idx,
264
+ cu_seqlens=cu_seqlens,
265
+ **dt_limit_kwargs,
266
+ return_final_states=ssm_state is not None,
267
+ return_varlen_states=cu_seqlens is not None and inference_params is not None,
268
+ )
269
+ if ssm_state is not None:
270
+ y, last_state, *rest = y
271
+ if cu_seqlens is None:
272
+ ssm_state.copy_(last_state)
273
+ else:
274
+ varlen_states = rest[0]
275
+ ssm_state.copy_(varlen_states)
276
+ y = rearrange(y, "b l h p -> b l (h p)")
277
+ if self.rmsnorm:
278
+ y_full = y
279
+ z_full = z
280
+
281
+ y = self.norm(y_full, z_full)
282
+ if d_mlp > 0:
283
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
284
+ if seqlen_og is not None:
285
+ y = rearrange(y, "b l d -> (b l) d")
286
+
287
+ out = self.out_proj(y)
288
+
289
+ return out, past_key_value
290
+
291
+
292
+ def step(self, hidden_states, conv_state, ssm_state):
293
+ dtype = hidden_states.dtype
294
+ # Remove single token limitation - now supports hidden_states.shape[1] > 1
295
+ batch_size, seq_len, _ = hidden_states.shape
296
+
297
+ if seq_len == 1:
298
+ # Single token case - keep existing optimized path
299
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
300
+ else:
301
+ # Multi-token case - process without squeezing
302
+ zxbcdt = self.in_proj(hidden_states) # (B L 2D)
303
+
304
+ d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
305
+
306
+ if seq_len == 1:
307
+ z0, x0, z, xBC, dt = torch.split(
308
+ zxbcdt,
309
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
310
+ dim=-1
311
+ )
312
+ else:
313
+ z0, x0, z, xBC, dt = torch.split(
314
+ zxbcdt,
315
+ [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],
316
+ dim=-1
317
+ )
318
+
319
+ # Conv step - handle both single and multi-token cases
320
+ if seq_len == 1:
321
+ # Single token optimized path
322
+ if causal_conv1d_update is None:
323
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
324
+ conv_state[:, :, -1] = xBC
325
+ xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
326
+ if self.conv1d.bias is not None:
327
+ xBC = xBC + self.conv1d.bias
328
+ xBC = self.act(xBC).to(dtype=dtype)
329
+ else:
330
+ xBC = causal_conv1d_update(
331
+ xBC,
332
+ conv_state,
333
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
334
+ self.conv1d.bias,
335
+ self.activation,
336
+ )
337
+ else:
338
+ # Multi-token case - update conv_state and process sequence
339
+ # Update conv_state with the new sequence
340
+ xBC_t = rearrange(xBC, "b l d -> b d l")
341
+ conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
342
+
343
+ # Process convolution for the full sequence
344
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
345
+ xBC = self.act(
346
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.d_conv - 1):]
347
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
348
+ else:
349
+ xBC = causal_conv1d_fn(
350
+ xBC.transpose(1, 2),
351
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
352
+ bias=self.conv1d.bias,
353
+ activation=self.activation,
354
+ ).transpose(1, 2)
355
+
356
+ x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
357
+ A = -torch.exp(self.A_log.float()) # (nheads,)
358
+
359
+ # SSM step - handle both single and multi-token cases
360
+ if seq_len == 1:
361
+ # Single token optimized path
362
+ if selective_state_update is None:
363
+ assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
364
+ # Discretize A and B
365
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
366
+ dA = torch.exp(dt * A) # (batch, nheads)
367
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
368
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
369
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
370
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
371
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
372
+ y = rearrange(y, "b h p -> b (h p)")
373
+ if not self.rmsnorm:
374
+ y = y * self.act(z) # (B D)
375
+ else:
376
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
377
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
378
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
379
+ D = repeat(self.D, "h -> h p", p=self.headdim)
380
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
381
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
382
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
383
+ if not self.rmsnorm:
384
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
385
+ y = selective_state_update(
386
+ ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,
387
+ dt_bias=dt_bias, dt_softplus=True
388
+ )
389
+ y = rearrange(y, "b h p -> b (h p)")
390
+ else:
391
+ # Multi-token case - use mamba_chunk_scan_combined similar to forward method
392
+ dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
393
+
394
+ y = mamba_chunk_scan_combined(
395
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
396
+ dt,
397
+ A,
398
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
399
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
400
+ chunk_size=self.chunk_size,
401
+ D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
402
+ z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None,
403
+ dt_bias=self.dt_bias,
404
+ dt_softplus=True,
405
+ **dt_limit_kwargs,
406
+ return_final_states=True,
407
+ )
408
+ # Extract final state and update ssm_state
409
+ y, final_ssm_state = y
410
+ ssm_state.copy_(final_ssm_state)
411
+ y = rearrange(y, "b l h p -> b l (h p)")
412
+
413
+ if self.rmsnorm:
414
+ y = self.norm(y, z)
415
+ if d_mlp > 0:
416
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
417
+ out = self.out_proj(y)
418
+
419
+ # Ensure output shape consistency
420
+ if seq_len == 1 and out.dim() == 2:
421
+ out = out.unsqueeze(1) # (B, 1, D)
422
+
423
+ return out, conv_state, ssm_state
424
+
425
+
426
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
427
+ device = self.out_proj.weight.device
428
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
429
+ conv_state = torch.zeros(
430
+ batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
431
+ ).transpose(1, 2)
432
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
433
+ ssm_state = torch.zeros(
434
+ batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
435
+ )
436
+ return conv_state, ssm_state
437
+
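For orientation, here is a hedged sketch of how the states returned by `allocate_inference_cache` could feed `step` during incremental decoding. The `mixer` object and the stream of hidden states are stand-ins; the real model wires these calls up through its inference-params machinery.

```python
import torch

# Hedged usage sketch: per-layer conv/ssm states from allocate_inference_cache feed step().
# `mixer` stands for an instance of this Mamba mixer module; its construction is omitted
# because the constructor arguments are defined elsewhere in this repository.
def decode_with_states(mixer, hidden_states_stream, max_seqlen=4096):
    batch_size = hidden_states_stream[0].shape[0]
    conv_state, ssm_state = mixer.allocate_inference_cache(batch_size, max_seqlen)
    outputs = []
    for hidden_states in hidden_states_stream:
        # hidden_states: (batch, 1, d_model) for single-token steps, or (batch, L, d_model)
        # for the multi-token path added above.
        out, conv_state, ssm_state = mixer.step(hidden_states, conv_state, ssm_state)
        outputs.append(out)
    return torch.cat(outputs, dim=1)
```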
438
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
439
+ assert self.layer_idx is not None
440
+ if self.layer_idx not in inference_params.key_value_memory_dict:
441
+ batch_shape = (batch_size,)
442
+ conv_state = torch.zeros(
443
+ batch_size,
444
+ self.d_conv,
445
+ self.conv1d.weight.shape[0],
446
+ device=self.conv1d.weight.device,
447
+ dtype=self.conv1d.weight.dtype,
448
+ ).transpose(1, 2)
449
+ ssm_state = torch.zeros(
450
+ batch_size,
451
+ self.nheads,
452
+ self.headdim,
453
+ self.d_state,
454
+ device=self.in_proj.weight.device,
455
+ dtype=self.in_proj.weight.dtype,
456
+ )
457
+ inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
458
+ else:
459
+ conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
460
+ # TODO: What if batch size changes between generation, and we reuse the same states?
461
+ if initialize_states:
462
+ conv_state.zero_()
463
+ ssm_state.zero_()
464
+ return conv_state, ssm_state
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fdb6a7f8e726e8e00f5f63dde79d18b499722f6ce7199027679cd01c9f71a87
3
+ size 1930804728
modeling_fast_slm.py ADDED
The diff for this file is too large to render. See raw diff
 
triton_attention.py ADDED
@@ -0,0 +1,2714 @@
1
+ """Custom ops for MHA/XQA attention."""
2
+
3
+ import math
4
+ from dataclasses import astuple
5
+ from typing import List, Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import triton
10
+
11
+ from triton import language as tl
12
+
13
+ from abc import ABC, abstractmethod
14
+ from dataclasses import dataclass, field, fields
15
+ from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union
16
+
17
+ import torch
18
+ from torch.export import Dim
19
+
20
+
21
+ @triton.jit
22
+ def update_kv_cache(
23
+ k_ptr, # [B*S, N, D]
24
+ v_ptr, # [B*S, N, D]
25
+ seq_len_ptr, # [b] # length of each sequence in a batch
26
+ seq_start_indices_ptr, # [b] # start indices of a sequence in flattened q/k/v.
27
+ k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
28
+ v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
29
+ input_pos_ptr, # Specifies the sequence index in the caches at which to write the provided kv
30
+ cache_loc_ptr, # Specifies the batch index for each of the input sequences
31
+ MAX_SEQ_LENGTH: tl.constexpr,
32
+ N_KV_HEADS: tl.constexpr,
33
+ Q_D_HEAD: tl.constexpr,
34
+ V_D_HEAD: tl.constexpr,
35
+ SEQ_BLOCK: tl.constexpr,
36
+ GENERATE_ONLY: tl.constexpr,
37
+ ):
38
+ batch_id = tl.program_id(axis=0)
39
+ head_id = tl.program_id(axis=1)
40
+ seq_block_id = tl.program_id(axis=2)
41
+
42
+ # Each program is responsible for a block of tokens in a single batch.
43
+ if GENERATE_ONLY:
44
+ seq_start_index = batch_id
45
+ seq_len: tl.constexpr = 1
46
+ else:
47
+ seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
48
+ seq_len = tl.load(seq_len_ptr + batch_id)
49
+
50
+ # cache is [bsnd]
51
+ # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
52
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
53
+
54
+ kv_position = tl.load(input_pos_ptr + batch_id)
55
+
56
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
57
+ k_cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * K_D_HEAD
58
+ v_cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * V_D_HEAD
59
+
60
+ k_dhead_offsets = tl.arange(0, triton.next_power_of_2(K_D_HEAD))
61
+ k_dhead_mask = k_dhead_offsets < K_D_HEAD
62
+
63
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
64
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
65
+
66
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
67
+ seq_mask = seq_offsets < seq_len
68
+
69
+ k_load_mask = seq_mask[:, None] * k_dhead_mask[None, :]
70
+ v_load_mask = seq_mask[:, None] * v_dhead_mask[None, :]
71
+
72
+ k_batch_offset = seq_start_index * N_KV_HEADS * K_D_HEAD
73
+ v_batch_offset = seq_start_index * N_KV_HEADS * V_D_HEAD
74
+ # Write back to kv-caches
75
+ ks = tl.load(
76
+ k_ptr
77
+ + k_batch_offset
78
+ + seq_offsets[:, None] * N_KV_HEADS * K_D_HEAD
79
+ + head_id * K_D_HEAD
80
+ + k_dhead_offsets[None, :],
81
+ mask=k_load_mask,
82
+ )
83
+ vs = tl.load(
84
+ v_ptr
85
+ + v_batch_offset
86
+ + seq_offsets[:, None] * N_KV_HEADS * V_D_HEAD
87
+ + head_id * V_D_HEAD
88
+ + v_dhead_offsets[None, :],
89
+ mask=v_load_mask,
90
+ )
91
+
92
+ kv_writeback_seq_offsets = seq_offsets + kv_position
93
+
94
+ k_cache_offset = (
95
+ k_cache_batch_offset
96
+ + kv_writeback_seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
97
+ + head_id * K_D_HEAD
98
+ + k_dhead_offsets[None, :]
99
+ )
100
+
101
+ v_cache_offset = (
102
+ v_cache_batch_offset
103
+ + kv_writeback_seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
104
+ + head_id * V_D_HEAD
105
+ + v_dhead_offsets[None, :]
106
+ )
107
+ tl.store(k_cache_ptr + k_cache_offset, ks, k_load_mask)
108
+ tl.store(v_cache_ptr + v_cache_offset, vs, v_load_mask)
109
+
110
+
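A host-side launch for `update_kv_cache` is not shown in this file. The sketch below illustrates one plausible way to call it, assuming contiguous `[B*S, N, D]` inputs and `[max_batch, max_seq, n_kv_heads, d_head]` caches; the grid layout and `SEQ_BLOCK` value are illustrative choices, not values taken from this repository.

```python
import triton

# Hedged sketch: launch update_kv_cache over (batch, kv_head, seq_block) programs.
def launch_update_kv_cache(k, v, seq_len, seq_start, k_cache, v_cache,
                           input_pos, cache_loc, seq_block=64):
    batch = seq_len.shape[0]
    max_seq_len, n_kv_heads = k_cache.shape[1], k_cache.shape[2]
    num_seq_blocks = triton.cdiv(int(seq_len.max()), seq_block)
    grid = (batch, n_kv_heads, num_seq_blocks)
    update_kv_cache[grid](
        k, v, seq_len, seq_start, k_cache, v_cache, input_pos, cache_loc,
        MAX_SEQ_LENGTH=max_seq_len,
        N_KV_HEADS=n_kv_heads,
        Q_D_HEAD=k_cache.shape[3],
        V_D_HEAD=v_cache.shape[3],
        SEQ_BLOCK=seq_block,
        GENERATE_ONLY=False,
    )
```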
111
+ @triton.jit
112
+ def gqa_attention_kv_stage1(
113
+ q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
114
+ k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
115
+ v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
116
+ cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
117
+ input_pos_ptr, # [Batch]
118
+ output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
119
+ output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
120
+ num_blocks,
121
+ MAX_SEQ_LEN: tl.constexpr, # Maximum supported sequence length
122
+ N_HEADS: tl.constexpr, # Number of heads
123
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
124
+ Q_D_HEAD: tl.constexpr, # Dimension of each query head.
125
+ V_D_HEAD: tl.constexpr, # Dimension of each key/value head
126
+ SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim.
127
+ HEAD_BLOCK_SIZE: tl.constexpr, # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
128
+ ):
129
+ """Attention kernel to be used for generate-only batches.
130
+
131
+ Specialized for GQA.
132
+
133
+ Assumes that kv caches have been updated.
134
+
135
+ Supports non-power-of-2 D_HEAD
136
+
137
+ Uses flash decoding.
138
+ KV-cache layout is assumed to be [Batch, Seq, Head, Dim]
139
+ 1. Fetch the K-cache from 0 to input_pos
140
+ 2. Fetch the V-cache from 0 to input_pos
141
+ 3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
142
+ 4. S = softmax(A)
143
+ 5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
144
+ """
145
+ # Assume KV-cache layout: [Batch, Seq, Head, Dim]
146
+ # A program is responsible for 1 batch, 1 head and a block of sequences.
147
+ batch_id = tl.program_id(axis=0)
148
+ kv_head_id = tl.program_id(axis=1)
149
+ seq_block_id = tl.program_id(axis=2)
150
+
151
+ kv_position = tl.load(input_pos_ptr + batch_id)
152
+ kv_batch_id = tl.load(cache_loc_ptr + batch_id)
153
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
154
+ batch_offset = kv_batch_id * N_KV_HEADS * MAX_SEQ_LEN
155
+
156
+ # Offsets for the block of sequences this program processes.
157
+ seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
158
+
159
+ # The number of Q heads that map to each KV head.
160
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS # This needs to be a power-of-2
161
+ if seq_start_pos > kv_position:
162
+ return
163
+ seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
164
+ seq_mask = seq_offsets <= kv_position
165
+
166
+ # Need to pad the head dim to 16 if HEAD_RATIO is < 16 so that tensor cores can be invoked
167
+ #
168
+ head_offsets = kv_head_id * HEAD_RATIO + tl.arange(0, HEAD_BLOCK_SIZE)
169
+ head_mask = head_offsets < (kv_head_id * HEAD_RATIO + HEAD_RATIO)
170
+ # Assuming D_HEAD is a power of 2
171
+ q_dhead_offsets = tl.arange(0, triton.next_power_of_2(Q_D_HEAD))
172
+ q_dhead_mask = q_dhead_offsets < Q_D_HEAD
173
+
174
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
175
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
176
+
177
+ sm_scale: tl.constexpr = 1.0 / (Q_D_HEAD**0.5)
178
+
179
+ # Program loads the entire Q for the head assigned to it.
180
+ # [NUM_HEADS, Q_D_HEAD]
181
+ q_batch_offset = batch_id * N_HEADS * Q_D_HEAD
182
+ q_head_offsets = head_offsets * Q_D_HEAD
183
+
184
+ # Q layout : BSND
185
+ q = tl.load(
186
+ q_ptr + q_batch_offset + q_head_offsets[:, None] + q_dhead_offsets[None, :],
187
+ mask=head_mask[:, None] * q_dhead_mask[None, :],
188
+ other=0.0,
189
+ )
190
+
191
+ # [BSND]
192
+ k_block_offsets = (
193
+ batch_offset * K_D_HEAD
194
+ + seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
195
+ + kv_head_id * K_D_HEAD
196
+ + q_dhead_offsets[None, :]
197
+ )
198
+ k_mask = seq_mask[:, None] * q_dhead_mask[None, :] # K and Q share the same head dim
199
+ k = tl.load(k_cache_ptr + k_block_offsets, mask=k_mask, other=0.0)
200
+
201
+ v_block_offsets = (
202
+ batch_offset * V_D_HEAD
203
+ + seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
204
+ + kv_head_id * V_D_HEAD
205
+ + v_dhead_offsets[None, :]
206
+ )
207
+ v_mask = seq_mask[:, None] * v_dhead_mask[None, :]
208
+
209
+ # [seq_block, V_D_HEAD]
210
+ v = tl.load(v_cache_ptr + v_block_offsets, mask=v_mask, other=0.0)
211
+
212
+ # Note: check the output precision of the sum.
213
+ # compute q*K^T
214
+ # [NUM_HEADS, Q_D_HEAD] * [seq_block, Q_D_HEAD], sum along axis 1
215
+ attn = tl.dot(q, k.trans()) # [N, seq_block]
216
+ attn = attn.to(tl.float32)
217
+ attn *= sm_scale
218
+ max_attn = tl.max(attn, axis=1) # [N, 1]
219
+ # Set attn values to -inf where the mask is not set. This forces exp(attn) to 0.
220
+ attn = tl.where(head_mask[:, None] * seq_mask[None, :], attn, float("-inf"))
221
+ exp_attn = tl.exp(attn - max_attn[:, None])
222
+
223
+ sumexp = tl.sum(exp_attn, axis=1) # [N, 1]
224
+
225
+ # [NUM_HEADS, seq_len] * [seq_len, V_D_HEAD], sum along axis 0
226
+ output = tl.dot(exp_attn.to(v.dtype), v)
227
+
228
+ output = output / sumexp[:, None] # [N, D_HEAD]
229
+
230
+ # We store the log-sum-exp after removing the max.
231
+ logsumexp = tl.log(sumexp) + max_attn
232
+ # when seq_mask is all false, max_attn will be -inf and sumexp is zero
233
+
234
+ tl.store(
235
+ output_values_ptr
236
+ + batch_id * N_HEADS * V_D_HEAD * num_blocks
237
+ + head_offsets[:, None] * V_D_HEAD * num_blocks
238
+ + seq_block_id * V_D_HEAD
239
+ + v_dhead_offsets[None, :],
240
+ output,
241
+ mask=head_mask[:, None] * v_dhead_mask[None, :],
242
+ )
243
+ tl.store(
244
+ output_logsumexp_ptr
245
+ + batch_id * N_HEADS * num_blocks
246
+ + head_offsets * num_blocks
247
+ + seq_block_id,
248
+ logsumexp,
249
+ mask=head_mask,
250
+ )
251
+
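The kernel above pads the group of query heads per KV head (`HEAD_RATIO`) up to `HEAD_BLOCK_SIZE` so that `tl.dot` can use tensor cores, which need tile dimensions of at least 16. A plausible host-side choice for this constant (an assumption, not taken from this repository) is:

```python
import triton

# Hedged sketch: pick HEAD_BLOCK_SIZE >= 16 so tl.dot hits tensor cores even when
# the number of query heads per KV head (HEAD_RATIO) is small.
def pick_head_block_size(n_heads: int, n_kv_heads: int) -> int:
    head_ratio = n_heads // n_kv_heads
    return max(16, triton.next_power_of_2(head_ratio))

print(pick_head_block_size(32, 8))  # 16 (ratio 4, padded up to the tensor-core minimum)
print(pick_head_block_size(64, 2))  # 32 (ratio already >= 16)
```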
252
+
253
+ @triton.jit
254
+ def attention_kv_stage1(
255
+ q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
256
+ k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
257
+ v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
258
+ cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
259
+ input_pos_ptr, # [Batch]
260
+ output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
261
+ output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
262
+ num_blocks,
263
+ MAX_SEQ_LEN: tl.constexpr, # Maximum supported sequence length
264
+ N_HEADS: tl.constexpr, # Number of heads
265
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
266
+ D_HEAD: tl.constexpr, # Dimension of each head.
267
+ SEQ_BLOCK_SIZE: tl.constexpr, # Block size used for tiling the sequence dim.
268
+ ):
269
+ """Attention kernel to be used for generate-only batches.
270
+
271
+ Assumes that kv caches have been updated.
272
+
273
+ Uses flash decoding.
274
+ KV-cache layout is assumed to be [Batch, Seq, Head, Dim]
275
+ 1. Fetch the K-cache from 0 to input_pos
276
+ 2. Fetch the V-cache from 0 to input_pos
277
+ 3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
278
+ 4. S = softmax(A)
279
+ 5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
280
+ """
281
+ # Assume KV-cache layout: [Batch, Seq, Head, Dim]
282
+ # A program is responsible for 1 batch, 1 head and a block of sequences.
283
+ batch_id = tl.program_id(axis=0)
284
+ head_id = tl.program_id(axis=1)
285
+ seq_block_id = tl.program_id(axis=2)
286
+ epsilon: tl.constexpr = 1e-38 # float32 smallest positive number
287
+
288
+ kv_position = tl.load(input_pos_ptr + batch_id)
289
+ kv_batch_id = tl.load(cache_loc_ptr + batch_id)
290
+ kv_batch_offset = kv_batch_id * N_KV_HEADS * MAX_SEQ_LEN * D_HEAD
291
+ # Offsets for the block of sequences this program processes.
292
+ seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
293
+
294
+ if seq_start_pos > kv_position:
295
+ return
296
+ seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
297
+ seq_mask = seq_offsets <= kv_position
298
+ # Assuming D_HEAD is a power of 2
299
+ dhead_offsets = tl.arange(0, triton.next_power_of_2(D_HEAD))
300
+ dhead_mask = dhead_offsets < D_HEAD
301
+
302
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
303
+ kv_head_offset = (head_id // HEAD_RATIO) * D_HEAD
304
+
305
+ sm_scale: tl.constexpr = 1.0 / (D_HEAD**0.5)
306
+
307
+ # Program loads the entire Q for the head assigned to it.
308
+ # [D_HEAD]
309
+ q_batch_offset = batch_id * N_HEADS * D_HEAD
310
+ q_head_offset = head_id * D_HEAD
311
+ q = tl.load(q_ptr + q_batch_offset + q_head_offset + dhead_offsets, mask=dhead_mask)
312
+
313
+ kv_block_offsets = (
314
+ kv_batch_offset
315
+ + seq_offsets[:, None] * D_HEAD * N_KV_HEADS
316
+ + kv_head_offset
317
+ + dhead_offsets[None, :]
318
+ ) # [BSND]
319
+ kv_mask = seq_mask[:, None] * dhead_mask[None, :]
320
+
321
+ # [seq_block, D_HEAD]
322
+ k = tl.load(k_cache_ptr + kv_block_offsets, mask=kv_mask, other=0.0)
323
+ v = tl.load(v_cache_ptr + kv_block_offsets, mask=kv_mask, other=0.0)
324
+
325
+ # Note: check the output precision of the sum.
326
+ # compute q*K^T
327
+ # [D_HEAD] * [seq_block, D_HEAD], sum along axis 1
328
+ attn = tl.sum(q[None, :].to(tl.float32) * k.to(tl.float32), axis=1) # [seq_block]
329
+
330
+ attn *= sm_scale
331
+ max_attn = tl.max(attn)
332
+ # Set attn values to -inf where the mask is not set. This forces exp(attn) to 0.
333
+ attn = tl.where(seq_mask, attn, float("-inf"))
334
+ exp_attn = tl.exp(attn - max_attn)
335
+ exp_attn = tl.where(exp_attn == 0, epsilon, exp_attn)
336
+ sumexp = tl.sum(exp_attn, axis=0) # scalar.
337
+
338
+ # [seq_len] * [seq_len, D_HEAD], sum along axis 0
339
+ output = tl.sum(exp_attn[:, None] * v, axis=0) # [D_HEAD]
340
+
341
+ output = output / sumexp
342
+
343
+ # We store the log-sum-exp after removing the max.
344
+ logsumexp = tl.log(sumexp) + max_attn
345
+ # when seq_mask is all false, max_attn will be -inf and sumexp is zero
346
+
347
+ tl.store(
348
+ output_values_ptr
349
+ + batch_id * N_HEADS * D_HEAD * num_blocks
350
+ + head_id * D_HEAD * num_blocks
351
+ + seq_block_id * D_HEAD
352
+ + dhead_offsets,
353
+ output,
354
+ mask=dhead_mask,
355
+ )
356
+ tl.store(
357
+ output_logsumexp_ptr
358
+ + batch_id * N_HEADS * num_blocks
359
+ + head_id * num_blocks
360
+ + seq_block_id,
361
+ logsumexp,
362
+ )
363
+
364
+
365
+ @triton.jit
366
+ def attention_kv_stage2(
367
+ values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
368
+ logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
369
+ output_ptr, # [Batch, N_HEADS, D_HEAD]
370
+ input_pos_ptr,
371
+ NUM_BLOCKS: tl.constexpr,
372
+ N_HEADS: tl.constexpr,
373
+ D_HEAD: tl.constexpr,
374
+ SEQ_BLOCK_SIZE: tl.constexpr, # Nearest power of 2 for num_blocks
375
+ ):
376
+ # There are batch * N_HEADS programs
377
+ batch_id = tl.program_id(axis=0)
378
+ head_id = tl.program_id(axis=1)
379
+
380
+ dhead_offsets = tl.arange(0, triton.next_power_of_2(D_HEAD))
381
+ dhead_mask = dhead_offsets < D_HEAD
382
+
383
+ kv_position = tl.load(input_pos_ptr + batch_id)
384
+ block_id = kv_position // SEQ_BLOCK_SIZE + 1
385
+
386
+ NUM_BLOCKS_POW2: tl.constexpr = triton.next_power_of_2(NUM_BLOCKS)
387
+ block_offsets = tl.arange(0, NUM_BLOCKS_POW2)
388
+
389
+ block_mask = block_offsets < block_id
390
+ logsumexp = tl.load(
391
+ logsumexp_ptr + batch_id * N_HEADS * NUM_BLOCKS + head_id * NUM_BLOCKS + block_offsets,
392
+ mask=block_mask,
393
+ other=float("-inf"),
394
+ )
395
+ max_logsumexp = tl.max(logsumexp)
396
+ sumexp = tl.exp(logsumexp - max_logsumexp) # [NUM_BLOCKS_POW2]
397
+
398
+ aggregate_sumexp = tl.sum(sumexp, axis=0)
399
+
400
+ values_offsets = block_offsets[:, None] * D_HEAD + dhead_offsets[None, :]
401
+ values_mask = block_mask[:, None] * dhead_mask[None, :]
402
+
403
+ values = tl.load(
404
+ values_ptr
405
+ + batch_id * N_HEADS * D_HEAD * NUM_BLOCKS
406
+ + head_id * D_HEAD * NUM_BLOCKS
407
+ + values_offsets,
408
+ mask=values_mask,
409
+ other=0.0,
410
+ ) # [BLOCK_SIZE, D_HEAD]
411
+ values *= sumexp[:, None]
412
+ values /= aggregate_sumexp
413
+
414
+ output = tl.sum(values, axis=0) # [DHEAD]
415
+
416
+ tl.store(
417
+ output_ptr + batch_id * N_HEADS * D_HEAD + head_id * D_HEAD + dhead_offsets,
418
+ output,
419
+ mask=dhead_mask,
420
+ )
421
+
422
+
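Stage 2 merges the per-block partial results produced by stage 1 using the stored log-sum-exp values. A small PyTorch reference of that merge, written for a single batch element and head with illustrative tensor names, looks roughly like this:

```python
import torch

# Reference for the stage-2 combine: each block i contributed a normalized partial
# output o[i] (D_HEAD) and its logsumexp l[i]; the exact softmax result is their
# weighted average with weights exp(l[i]), renormalized in a numerically stable way.
def combine_blocks(partial_out: torch.Tensor, logsumexp: torch.Tensor) -> torch.Tensor:
    # partial_out: [num_blocks, d_head], logsumexp: [num_blocks]
    m = logsumexp.max()
    w = torch.exp(logsumexp - m)                     # [num_blocks]
    return (w[:, None] * partial_out).sum(0) / w.sum()

# Sanity check against a direct softmax attention over two concatenated blocks.
torch.manual_seed(0)
q = torch.randn(64); k = torch.randn(48, 64); v = torch.randn(48, 64)
scores = (k @ q) / 64**0.5
full = torch.softmax(scores, 0) @ v
parts, lses = [], []
for blk in range(2):
    s = scores[blk * 24:(blk + 1) * 24]
    e = torch.exp(s - s.max())
    parts.append((e[:, None] * v[blk * 24:(blk + 1) * 24]).sum(0) / e.sum())
    lses.append(torch.log(e.sum()) + s.max())
assert torch.allclose(combine_blocks(torch.stack(parts), torch.stack(lses)), full, atol=1e-5)
```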
423
+ @triton.jit
424
+ def context_attention_kv(
425
+ q_ptr, # [bsnd]
426
+ k_ptr, # [bsnd]
427
+ v_ptr, # [bsnd]
428
+ k_cache_ptr, # [bsnd]
429
+ v_cache_ptr, # [bsnd]
430
+ seq_len,
431
+ o_ptr,
432
+ softmax_scale,
433
+ N_HEADS: tl.constexpr, # Number of heads
434
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
435
+ Q_D_HEAD: tl.constexpr, # Dimension of each query head.
436
+ V_D_HEAD: tl.constexpr, # Dimension of each value head.
437
+ SEQ_BLOCK: tl.constexpr,
438
+ MAX_SEQ_LENGTH: tl.constexpr,
439
+ ):
440
+ """Kernel for context phase.
441
+
442
+ Assuming:
443
+ 1. Self-attention [seqlen(Q) == seqlen(K)]
444
+ 2. Causal attention
445
+ 3. QKV layout: [bsnd]
446
+ """
447
+ batch_id = tl.program_id(axis=0)
448
+ head_id = tl.program_id(axis=1)
449
+ seq_block_id = tl.program_id(axis=2)
450
+
451
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
452
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
453
+
454
+ q_dhead_offsets = tl.arange(0, triton.next_power_of_2(Q_D_HEAD))
455
+ q_dhead_mask = q_dhead_offsets < Q_D_HEAD
456
+
457
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
458
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
459
+
460
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
461
+ seq_mask = seq_offsets < seq_len
462
+
463
+ q_load_mask = seq_mask[:, None] * q_dhead_mask[None, :]
464
+
465
+ q_batch_offset = batch_id * seq_len * N_HEADS
466
+ kv_batch_offset = batch_id * seq_len * N_KV_HEADS
467
+
468
+ k_head_offset = (head_id // HEAD_RATIO) * K_D_HEAD
469
+ v_head_offset = (head_id // HEAD_RATIO) * V_D_HEAD
470
+
471
+ # Q will stay in SRAM
472
+ q = tl.load(
473
+ q_ptr
474
+ + q_batch_offset * Q_D_HEAD
475
+ + seq_offsets[:, None] * N_HEADS * Q_D_HEAD
476
+ + head_id * Q_D_HEAD
477
+ + q_dhead_offsets[None, :],
478
+ mask=q_load_mask,
479
+ )
480
+ acc = tl.zeros([SEQ_BLOCK, triton.next_power_of_2(V_D_HEAD)], dtype=tl.float32)
481
+ lse_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
482
+ m_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
483
+
484
+ for s in range(0, seq_block_id + 1, 1):
485
+ kv_seq_offsets = s * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
486
+ kv_seq_mask = kv_seq_offsets < seq_len
487
+ k_load_mask = kv_seq_mask[:, None] * q_dhead_mask[None, :]
488
+
489
+ k = tl.load(
490
+ k_ptr
491
+ + kv_batch_offset * K_D_HEAD
492
+ + kv_seq_offsets[:, None] * N_KV_HEADS * K_D_HEAD
493
+ + k_head_offset
494
+ + q_dhead_offsets[None, :],
495
+ mask=k_load_mask,
496
+ )
497
+ qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
498
+ qk += tl.dot(q, k.trans())
499
+ # causal mask
500
+ qk = tl.where(seq_offsets[:, None] >= kv_seq_offsets[None, :], qk, float("-inf"))
501
+ qk *= softmax_scale
502
+ # rowmax
503
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
504
+ p = tl.exp(qk - m_ij[:, None]) # [S,S]
505
+ v = tl.load(
506
+ v_ptr
507
+ + kv_batch_offset * V_D_HEAD
508
+ + kv_seq_offsets[:, None] * N_KV_HEADS * V_D_HEAD
509
+ + v_head_offset
510
+ + v_dhead_offsets[None, :],
511
+ mask=kv_seq_mask[:, None] * v_dhead_mask[None, :],
512
+ )
513
+
514
+ l_ij = tl.sum(p, 1)
515
+ acc_scale = tl.exp(m_i - m_ij)
516
+ acc = acc * acc_scale[:, None]
517
+ p = p.to(v.dtype)
518
+ acc += tl.dot(p, v)
519
+ m_i = m_ij
520
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
521
+ lse_i = m_ij + tl.log(l_i_new)
522
+
523
+ o_scale = tl.exp(m_i - lse_i)
524
+
525
+ acc = acc * o_scale[:, None]
526
+
527
+ tl.store(
528
+ o_ptr
529
+ + batch_id * seq_len * N_HEADS * V_D_HEAD
530
+ + seq_offsets[:, None] * N_HEADS * V_D_HEAD
531
+ + head_id * V_D_HEAD
532
+ + v_dhead_offsets[None, :],
533
+ acc,
534
+ mask=seq_mask[:, None] * v_dhead_mask[None, :],
535
+ )
536
+
537
+ # Write back to kv-caches
538
+
539
+ ks = tl.load(
540
+ k_ptr
541
+ + kv_batch_offset * K_D_HEAD
542
+ + seq_offsets[:, None] * N_KV_HEADS * K_D_HEAD
543
+ + k_head_offset
544
+ + q_dhead_offsets[None, :],
545
+ mask=seq_mask[:, None] * q_dhead_mask[None, :],
546
+ )
547
+ vs = tl.load(
548
+ v_ptr
549
+ + kv_batch_offset * V_D_HEAD
550
+ + seq_offsets[:, None] * N_KV_HEADS * V_D_HEAD
551
+ + v_head_offset
552
+ + v_dhead_offsets[None, :],
553
+ mask=seq_mask[:, None] * v_dhead_mask[None, :],
554
+ )
555
+ # cache is [bsnd]
556
+ k_cache_offset = (
557
+ batch_id * N_KV_HEADS * MAX_SEQ_LENGTH * K_D_HEAD
558
+ + seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
559
+ + k_head_offset
560
+ + q_dhead_offsets[None, :]
561
+ )
562
+
563
+ v_cache_offset = (
564
+ batch_id * N_KV_HEADS * MAX_SEQ_LENGTH * V_D_HEAD
565
+ + seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
566
+ + v_head_offset
567
+ + v_dhead_offsets[None, :]
568
+ )
569
+ tl.store(k_cache_ptr + k_cache_offset, ks, seq_mask[:, None] * q_dhead_mask[None, :])
570
+ tl.store(v_cache_ptr + v_cache_offset, vs, seq_mask[:, None] * v_dhead_mask[None, :])
571
+
572
+
573
+ @triton.jit
574
+ def context_attention_kv_flattened(
575
+ q_ptr, # [b*s,nd]
576
+ seq_len_ptr, # [b] # length of each sequence in a batch
577
+ seq_start_indices_ptr, # [b] # start indices of a sequence in flattened q/k/v.
578
+ k_cache_ptr, # [bsnd]
579
+ v_cache_ptr, # [bsnd]
580
+ input_pos_ptr, # [b] # specifies the location in the sequence where kv must be written back.
581
+ cache_loc_ptr, # [b] # location of the sequence in the cache.
582
+ o_ptr,
583
+ softmax_scale: tl.constexpr,
584
+ N_HEADS: tl.constexpr, # Number of heads
585
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
586
+ Q_D_HEAD: tl.constexpr, # Dimension of each query head.
587
+ V_D_HEAD: tl.constexpr, # Dimension of each value head.
588
+ SEQ_BLOCK: tl.constexpr,
589
+ MAX_SEQ_LENGTH: tl.constexpr,
590
+ ):
591
+ """Kernel for context phase.
592
+
593
+ Assumes that kv caches have been updated.
594
+ Assuming QKV layout: [b*s,n,d]
595
+ """
596
+ batch_id = tl.program_id(axis=0)
597
+ head_id = tl.program_id(axis=1)
598
+ seq_block_id = tl.program_id(axis=2)
599
+
600
+ # Each program is responsible for a block of tokens in a single batch.
601
+ seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
602
+ seq_len = tl.load(seq_len_ptr + batch_id)
603
+ K_D_HEAD: tl.constexpr = Q_D_HEAD
604
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
605
+
606
+ # cache is [bsnd]
607
+ # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
608
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
609
+
610
+ cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH
611
+ cache_head_offset = head_id // HEAD_RATIO
612
+
613
+ q_dhead_offsets = tl.arange(0, triton.next_power_of_2(Q_D_HEAD))
614
+ q_dhead_mask = q_dhead_offsets < Q_D_HEAD
615
+
616
+ v_dhead_offsets = tl.arange(0, triton.next_power_of_2(V_D_HEAD))
617
+ v_dhead_mask = v_dhead_offsets < V_D_HEAD
618
+
619
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
620
+ seq_mask = seq_offsets < seq_len
621
+
622
+ # Q will stay in SRAM
623
+ q = tl.load(
624
+ q_ptr
625
+ + seq_start_index * N_HEADS * Q_D_HEAD
626
+ + seq_offsets[:, None] * N_HEADS * Q_D_HEAD
627
+ + head_id * Q_D_HEAD
628
+ + q_dhead_offsets[None, :],
629
+ mask=seq_mask[:, None] * q_dhead_mask[None, :],
630
+ )
631
+
632
+ acc = tl.zeros([SEQ_BLOCK, triton.next_power_of_2(V_D_HEAD)], dtype=tl.float32)
633
+ lse_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
634
+ m_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
635
+
636
+ # Loop over the entire KV-history
637
+ # input_pos_ptr stores the location at which kv must be written back for the given batch.
638
+ kv_position = tl.load(input_pos_ptr + batch_id)
639
+ num_blocks = (kv_position + seq_len + SEQ_BLOCK - 1) // SEQ_BLOCK
640
+ for s in range(0, num_blocks + 1, 1):
641
+ kv_seq_offsets = s * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
642
+ kv_seq_mask = kv_seq_offsets < (kv_position + seq_len)
643
+
644
+ k = tl.load(
645
+ k_cache_ptr
646
+ + cache_batch_offset * K_D_HEAD
647
+ + kv_seq_offsets[:, None] * K_D_HEAD * N_KV_HEADS
648
+ + cache_head_offset * K_D_HEAD
649
+ + q_dhead_offsets[None, :],
650
+ mask=kv_seq_mask[:, None] * q_dhead_mask[None, :],
651
+ )
652
+ qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
653
+ qk += tl.dot(q, k.trans())
654
+ qk = tl.where(
655
+ (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf")
656
+ )
657
+ qk *= softmax_scale
658
+ # rowmax
659
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
660
+ p = tl.exp(qk - m_ij[:, None])
661
+ v = tl.load(
662
+ v_cache_ptr
663
+ + cache_batch_offset * V_D_HEAD
664
+ + kv_seq_offsets[:, None] * V_D_HEAD * N_KV_HEADS
665
+ + cache_head_offset * V_D_HEAD
666
+ + v_dhead_offsets[None, :],
667
+ mask=kv_seq_mask[:, None] * v_dhead_mask[None, :],
668
+ )
669
+
670
+ l_ij = tl.sum(p, 1)
671
+ acc_scale = tl.exp(m_i - m_ij)
672
+ acc = acc * acc_scale[:, None]
673
+ p = p.to(v.dtype)
674
+ acc += tl.dot(p, v)
675
+ m_i = m_ij
676
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
677
+ lse_i = m_ij + tl.log(l_i_new)
678
+
679
+ o_scale = tl.exp(m_i - lse_i)
680
+
681
+ acc = acc * o_scale[:, None]
682
+
683
+ tl.store(
684
+ o_ptr
685
+ + seq_start_index * N_HEADS * V_D_HEAD
686
+ + seq_offsets[:, None] * N_HEADS * V_D_HEAD
687
+ + head_id * V_D_HEAD
688
+ + v_dhead_offsets[None, :],
689
+ acc,
690
+ mask=seq_mask[:, None] * v_dhead_mask[None, :],
691
+ )
692
+
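Both context kernels above use the standard online-softmax recurrence (`m_i`, `lse_i`, and a rescaled accumulator) while streaming over KV blocks. A compact PyTorch reference of that recurrence, for a single query vector and illustrative shapes, is sketched below.

```python
import torch

# Online-softmax attention over KV blocks, mirroring the m_i / lse_i / acc updates
# in the context kernels (single query vector, no masking, illustrative only).
def online_softmax_attention(q, k, v, block=16):
    scale = q.shape[-1] ** -0.5
    acc = torch.zeros(v.shape[-1])
    lse = torch.tensor(float("-inf"))
    m = torch.tensor(float("-inf"))
    for s in range(0, k.shape[0], block):
        qk = (k[s:s + block] @ q) * scale            # [block] of scaled scores
        m_new = torch.maximum(qk.max(), lse)         # running max (kernel also folds in lse)
        p = torch.exp(qk - m_new)
        acc = acc * torch.exp(m - m_new) + p @ v[s:s + block]
        lse = m_new + torch.log(torch.exp(lse - m_new) + p.sum())
        m = m_new
    return acc * torch.exp(m - lse)                  # final o_scale, as in the kernels

torch.manual_seed(0)
q, k, v = torch.randn(32), torch.randn(40, 32), torch.randn(40, 64)
ref = torch.softmax((k @ q) * 32**-0.5, 0) @ v
assert torch.allclose(online_softmax_attention(q, k, v), ref, atol=1e-5)
```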
693
+
694
+ @triton.jit
695
+ def update_kv_cache_rope_fusion(
696
+ q_ptr, # [B*S, N, D]
697
+ k_ptr, # [B*S, N, D]
698
+ v_ptr, # [B*S, N, D]
699
+ seq_len_ptr, # [b] # length of each sequence in a batch
700
+ seq_start_indices_ptr, # [b] # start indices of a sequence in flattened q/k/v.
701
+ q_rope_ptr, # [B*S, N, D], roped q result
702
+ k_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
703
+ v_cache_ptr, # [MAX_BATCH_SIZE, MAX_SEQ_LEN, N_HEADS, D_HEAD]
704
+ input_pos_ptr, # Specifies the sequence index in the caches at which to write the provided kv
705
+ cache_loc_ptr, # Specifies the batch index for each of the input sequences
706
+ f_ptr, # [MAX_SEQ_LEN, D_HEAD//2, 2] # frequencies for rope embedding.
707
+ MAX_SEQ_LENGTH: tl.constexpr,
708
+ N_HEADS: tl.constexpr,
709
+ N_KV_HEADS: tl.constexpr,
710
+ D_HEAD: tl.constexpr,
711
+ SEQ_BLOCK: tl.constexpr,
712
+ HEAD_BLOCK_SIZE: tl.constexpr, # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
713
+ GENERATE_ONLY: tl.constexpr,
714
+ ):
715
+ """Fuse q and k rope with update_kv_cache kernel.
716
+
717
+ The input is interleaved as [2, D//2] in D_HEAD dim.
718
+ Update q_rope with the post-rope-embedding q values.
719
+ Update k_cache with the post-rope-embedding k values.
720
+ For rope computation, q and k are loaded and stored as pairs of two [D//2] tensors.
721
+ Update v_cache with v.
722
+ """
723
+ batch_id = tl.program_id(axis=0)
724
+ kv_head_id = tl.program_id(axis=1)
725
+ seq_block_id = tl.program_id(axis=2)
726
+
727
+ # Each program is responsible for a block of tokens in a single batch.
728
+ if GENERATE_ONLY:
729
+ seq_start_index = batch_id
730
+ seq_len: tl.constexpr = 1
731
+ else:
732
+ seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
733
+ seq_len = tl.load(seq_len_ptr + batch_id)
734
+
735
+ # cache is [bsnd]
736
+ # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
737
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
738
+
739
+ kv_position = tl.load(input_pos_ptr + batch_id)
740
+
741
+ cache_batch_offset = cache_loc * N_KV_HEADS * MAX_SEQ_LENGTH * D_HEAD
742
+ cache_head_offset = kv_head_id * D_HEAD
743
+
744
+ # Assuming D_HEAD is a power of 2
745
+ dhead_offsets = tl.arange(0, D_HEAD)
746
+ dhead_mask = dhead_offsets < D_HEAD
747
+
748
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
749
+ seq_mask = seq_offsets < seq_len
750
+
751
+ load_mask = seq_mask[:, None] * dhead_mask[None, :]
752
+
753
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS # This needs to be a power-of-2
754
+ q_head_offsets = kv_head_id * HEAD_RATIO + tl.arange(0, HEAD_BLOCK_SIZE)
755
+ q_head_mask = q_head_offsets < (kv_head_id * HEAD_RATIO + HEAD_RATIO)
756
+
757
+ q_batch_offset = seq_start_index * N_HEADS * D_HEAD
758
+
759
+ kv_batch_offset = seq_start_index * N_KV_HEADS * D_HEAD
760
+ kv_head_offset = cache_head_offset
761
+
762
+ D2: tl.constexpr = D_HEAD // 2
763
+ # input is interleaved as [2, D//2] in dim [D_HEAD].
764
+ d2_offsets = tl.arange(0, D2)
765
+ dhead_offsets1 = d2_offsets
766
+ dhead_offsets2 = d2_offsets + D2
767
+ d2_mask = dhead_offsets2 < D_HEAD
768
+ d2_load_mask = seq_mask[:, None] * d2_mask[None, :]
769
+
770
+ # offsets of [bsn]
771
+ q_offsets_base = (
772
+ q_batch_offset
773
+ + seq_offsets[:, None, None] * N_HEADS * D_HEAD
774
+ + q_head_offsets[None, :, None] * D_HEAD
775
+ )
776
+ q_offsets1 = q_offsets_base + dhead_offsets1[None, None, :]
777
+ q_offsets2 = q_offsets_base + dhead_offsets2[None, None, :]
778
+ q_mask = d2_load_mask[:, None, :] * q_head_mask[None, :, None]
779
+
780
+ q1 = tl.load(q_ptr + q_offsets1, mask=q_mask).to(tl.float32)
781
+ q2 = tl.load(q_ptr + q_offsets2, mask=q_mask).to(tl.float32)
782
+
783
+ k_offsets_base = kv_batch_offset + seq_offsets[:, None] * N_KV_HEADS * D_HEAD + kv_head_offset
784
+ k_offsets1 = k_offsets_base + dhead_offsets1[None, :]
785
+ k_offsets2 = k_offsets_base + dhead_offsets2[None, :]
786
+
787
+ k1 = tl.load(k_ptr + k_offsets1, mask=d2_load_mask).to(tl.float32)
788
+ k2 = tl.load(k_ptr + k_offsets2, mask=d2_load_mask).to(tl.float32)
789
+
790
+ # -----------------------------------
791
+ # torch version sin/cos
792
+ # cos and sin values are interleaved in frequencies tensor.
793
+ f_offsets = seq_offsets[:, None] * D2 + d2_offsets[None, :]
794
+ cos_ref = tl.load(f_ptr + kv_position * D_HEAD + f_offsets * 2, mask=d2_load_mask).to(
795
+ dtype=tl.float32
796
+ )
797
+ sin_ref = tl.load(f_ptr + kv_position * D_HEAD + f_offsets * 2 + 1, mask=d2_load_mask).to(
798
+ dtype=tl.float32
799
+ )
800
+
801
+ qs1 = cos_ref[:, None, :] * q1 - sin_ref[:, None, :] * q2
802
+ qs2 = sin_ref[:, None, :] * q1 + cos_ref[:, None, :] * q2
803
+
804
+ tl.store(q_rope_ptr + q_offsets1, qs1, mask=q_mask)
805
+ tl.store(q_rope_ptr + q_offsets2, qs2, mask=q_mask)
806
+
807
+ ks1 = cos_ref * k1 - sin_ref * k2
808
+ ks2 = sin_ref * k1 + cos_ref * k2
809
+
810
+ # Write back to kv-caches
811
+ vs = tl.load(
812
+ v_ptr
813
+ + kv_batch_offset
814
+ + seq_offsets[:, None] * N_KV_HEADS * D_HEAD
815
+ + kv_head_offset
816
+ + dhead_offsets[None, :],
817
+ mask=load_mask,
818
+ )
819
+
820
+ kv_writeback_seq_offsets = seq_offsets + kv_position
821
+
822
+ cache_offset_base = (
823
+ cache_batch_offset
824
+ + kv_writeback_seq_offsets[:, None] * D_HEAD * N_KV_HEADS
825
+ + cache_head_offset
826
+ )
827
+
828
+ k_cache_offset1 = cache_offset_base + dhead_offsets1[None, :]
829
+ k_cache_offset2 = cache_offset_base + dhead_offsets2[None, :]
830
+ tl.store(k_cache_ptr + k_cache_offset1, ks1, mask=d2_load_mask)
831
+ tl.store(k_cache_ptr + k_cache_offset2, ks2, mask=d2_load_mask)
832
+
833
+ v_cache_offset = cache_offset_base + dhead_offsets[None, :]
834
+ tl.store(v_cache_ptr + v_cache_offset, vs, load_mask)
835
+
836
+
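The fused kernel rotates q and k by splitting the head dimension into two halves (`[..., :D//2]` and `[..., D//2:]`) and applying per-position cos/sin values that are interleaved in the frequency table. A plain PyTorch reference of the same rotation, with illustrative tensor names and the usual theta-based frequency convention (which may differ from how this repository builds its frequency table), is:

```python
import torch

# Reference for the rotation used above: the head dim is split into halves (x1, x2)
# and rotated by per-position cos/sin values.
def rotate_half_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [..., d_head], cos/sin: [..., d_head // 2]
    d2 = x.shape[-1] // 2
    x1, x2 = x[..., :d2], x[..., d2:]
    return torch.cat([cos * x1 - sin * x2, sin * x1 + cos * x2], dim=-1)

# Example with illustrative shapes: one token at position 3, head dim 8.
pos, d_head, theta = 3, 8, 10000.0
inv_freq = theta ** (-torch.arange(0, d_head, 2).float() / d_head)   # [d_head // 2]
angles = pos * inv_freq
x = torch.randn(d_head)
x_rot = rotate_half_rope(x, torch.cos(angles), torch.sin(angles))
```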
837
+
838
+ """
839
+ Kernels based on paged KV Cache.
840
+ Parameter infos:
841
+ tensors:
842
+ - q: [b*s, n, d], flattened queries.
843
+ - k/v: [b*s, n, d], flattened key/value.
844
+ - seq_len: [b], length of each sequence in the batch.
845
+ `seq_len` can be 1 (generate) or larger (context).
846
+ - seq_start: [b], start index of each sequence in b*s dim of q/k/v.
847
+ - k_cache/v_cache: [num_pages, PAGE_SIZE, n, d], paged KV Cache.
848
+ Incoming k/v is split into groups of PAGE_SIZE tokens, which are then
849
+ mapped to non-contiguous pages in the KV cache.
850
+ - page_table: [b, max_num_pages_per_seq], page indices assigned to each sequence.
851
+ - cache_loc: [b], maps the `batch_id` in q/k/v to a row index in `page_table`.
852
+ - cache_len: [b], existing cached k/v length of each sequence.
853
+
854
+ constexpr:
855
+ - N_HEADS/N_KV_HEADS: shape of dim [n] in q or k/v.
856
+ - D_HEAD: shape of dim [d] in q/k/v.
857
+ Assuming power of 2.
858
+ - SEQ_BLOCK: block size to split dim [s].
859
+ Assuming power of 2.
860
+ Split k/v in update kernel and split q in context/generate kernel.
861
+ - MAX_SEQ_LENGTH: seq_len <= MAX_SEQ_LENGTH.
862
+ - PAGE_SIZE: number of tokens per kv cache page.
863
+ Assumed to be a power of 2 with SEQ_BLOCK % PAGE_SIZE == 0.
864
+ - PAGE_TABLE_STRIDE: stride of dim [b] in `page_table`.
865
+
866
+ KV Cache access logic in update kernel:
867
+ 1. Batch entry i accesses k[seq_start[i] : seq_start[i] + seq_len[i]]
868
+ and can be split into pages [a:b] in the sequence.
869
+ 2. Look up cache_len[i] to find if the sequence has cached k/v.
870
+ 3. Look up page_table[cache_loc[i], cache_len[i] + a : cache_len[i] + b]
871
+ to get the corresponding pages in the k_cache, with result [c:d].
872
+ 4. Then update k_cache[c:d] with the k value.
873
+
874
+ """
875
+
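The lookup described in steps 1-4 can be written out in plain Python to make the index math concrete. This is a hedged reference for a single token, not code from this repository.

```python
# Reference for the page-table lookup (steps 1-4 above): map the t-th new token of
# batch entry i to its physical slot index in the paged cache.
def physical_slot(page_table, cache_loc, cache_len, i, t, page_size):
    logical_pos = cache_len[i] + t                       # position within the full sequence
    page = page_table[cache_loc[i]][logical_pos // page_size]
    return page * page_size + logical_pos % page_size

# Example: PAGE_SIZE=4, sequence i=0 already has 6 cached tokens on pages [7, 2, ...].
page_table = [[7, 2, 5, 9]]
print(physical_slot(page_table, cache_loc=[0], cache_len=[6], i=0, t=0, page_size=4))  # 2*4 + 2 = 10
print(physical_slot(page_table, cache_loc=[0], cache_len=[6], i=0, t=3, page_size=4))  # 5*4 + 1 = 21
```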
876
+
877
+ @triton.jit
878
+ def update_paged_kv_cache(
879
+ k_ptr, # [B*S, N, D]
880
+ v_ptr, # [B*S, N, D]
881
+ seq_len_ptr, # [b] # length of each sequence in a batch
882
+ seq_start_indices_ptr, # [b] # start indices of a sequence in flattened q/k/v.
883
+ k_cache_ptr, # [num_pages, page_size, n, d]
884
+ v_cache_ptr, # [num_pages, page_size, n, d]
885
+ cache_loc_ptr, # [b] # index of the sequence in the page table.
886
+ cache_len_ptr, # [b] # length of the sequence already in kv cache.
887
+ page_table_ptr, # [b, max_num_pages_per_seq] # loc of the block page in the cache.
888
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
889
+ D_HEAD: tl.constexpr, # Dimension of each head.
890
+ SEQ_BLOCK: tl.constexpr,
891
+ MAX_SEQ_LENGTH: tl.constexpr,
892
+ PAGE_SIZE: tl.constexpr,
893
+ PAGE_TABLE_STRIDE: tl.constexpr,
894
+ GENERATE_ONLY: tl.constexpr,
895
+ ):
896
+ batch_id = tl.program_id(axis=0)
897
+ head_id = tl.program_id(axis=1)
898
+ seq_block_id = tl.program_id(axis=2)
899
+
900
+ # Each program is responsible for a block of tokens in a single batch.
901
+ if GENERATE_ONLY:
902
+ seq_start_index = batch_id
903
+ seq_len: tl.constexpr = 1
904
+ else:
905
+ seq_start_index = tl.load(seq_start_indices_ptr + batch_id)
906
+ seq_len = tl.load(seq_len_ptr + batch_id)
907
+
908
+ cache_len = tl.load(cache_len_ptr + batch_id)
909
+
910
+ # cache is [num_pages, page_size, n, d]
911
+ # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
912
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
913
+ cache_head_offset = head_id * D_HEAD
914
+
915
+ # Assuming D_HEAD is a power of 2
916
+ dhead_offsets = tl.arange(0, D_HEAD)
917
+ dhead_mask = dhead_offsets < D_HEAD
918
+
919
+ seq_offsets = seq_block_id * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
920
+ seq_mask = seq_offsets < seq_len
921
+
922
+ load_mask = seq_mask[:, None] * dhead_mask[None, :]
923
+
924
+ kv_batch_offset = seq_start_index * N_KV_HEADS * D_HEAD
925
+ kv_head_offset = cache_head_offset
926
+
927
+ # Write back to kv-caches
928
+ ks = tl.load(
929
+ k_ptr
930
+ + kv_batch_offset
931
+ + seq_offsets[:, None] * N_KV_HEADS * D_HEAD
932
+ + kv_head_offset
933
+ + dhead_offsets[None, :],
934
+ mask=load_mask,
935
+ )
936
+ vs = tl.load(
937
+ v_ptr
938
+ + kv_batch_offset
939
+ + seq_offsets[:, None] * N_KV_HEADS * D_HEAD
940
+ + kv_head_offset
941
+ + dhead_offsets[None, :],
942
+ mask=load_mask,
943
+ )
944
+
945
+ # assuming SEQ_BLOCK can be divided by PAGE_SIZE and PAGE_SIZE is a power of 2.
946
+ SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK // PAGE_SIZE
947
+ MAX_NUM_PAGES: tl.constexpr = (MAX_SEQ_LENGTH + PAGE_SIZE - 1) // PAGE_SIZE
948
+ # cache_len // PAGE_SIZE means history pages
949
+ # if decode sequence, then seq_len = 1 and only seq_block_id = 0 works,
950
+ kv_pages = seq_block_id * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE) + cache_len // PAGE_SIZE
951
+ cache_pages = tl.load(
952
+ page_table_ptr + cache_loc * PAGE_TABLE_STRIDE + kv_pages, mask=kv_pages < MAX_NUM_PAGES
953
+ )
954
+
955
+ page_offsets = tl.arange(0, PAGE_SIZE)
956
+ # shape [SEQ_BLOCK], means [cache_pages, page_offsets]
957
+ cache_seq_offset = tl.reshape(
958
+ cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK]
959
+ )
960
+ # write offset inside the page
961
+ cache_seq_offset += cache_len % PAGE_SIZE
962
+
963
+ cache_offsets = (
964
+ cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD + kv_head_offset + dhead_offsets[None, :]
965
+ )
966
+ tl.store(k_cache_ptr + cache_offsets, ks, load_mask)
967
+ tl.store(v_cache_ptr + cache_offsets, vs, load_mask)
968
+
969
+
970
+ # TODO: Write a doc describing the 2 stage algorithm
971
+ @triton.jit
972
+ def attention_kv_paged_stage1(
973
+ q_ptr, # [Batch, 1, N_HEADS, D_HEAD]
974
+ k_cache_ptr, # [NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD]
975
+ v_cache_ptr, # [NUM_PAGES, PAGE_SIZE, N_HEADS, D_HEAD]
976
+ cache_loc_ptr, # [Batch] # Specifies the batch index for each of the generate tokens.
977
+ page_table_ptr, # [Batch, num_pages_per_seq]
978
+ cache_len_ptr, # [Batch] # Number of tokens in kv cache.
979
+ output_values_ptr, # [Batch, N_HEADS, num_blocks, D_HEAD]
980
+ output_logsumexp_ptr, # [Batch, N_HEADS, num_blocks]
981
+ num_blocks,
982
+ MAX_SEQ_LEN: tl.constexpr, # Maximum supported sequence length
983
+ N_HEADS: tl.constexpr, # Number of heads
984
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
985
+ D_HEAD: tl.constexpr, # Dimension of each head.
986
+ # Block size used for tiling the sequence dim.
987
+ SEQ_BLOCK_SIZE: tl.constexpr,
988
+ PAGE_SIZE: tl.constexpr,
989
+ PAGE_TABLE_STRIDE: tl.constexpr,
990
+ ):
991
+ """Attention kernel to be used during the generate phase.
992
+
993
+ Uses flash decoding.
994
+ KV-cache layout is paged: [num_pages, PAGE_SIZE, N_KV_HEADS, D_HEAD]
995
+ 1. Fetch the K-cache from 0 to input_pos
996
+ 2. Fetch the V-cache from 0 to input_pos
997
+ 3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
998
+ 4. S = softmax(A)
999
+ 5. O = S*V [1, seq_len] * [1, seq_len, D_HEAD] -> [1, D_HEAD]
1000
+ """
1001
+ # Assume a paged KV-cache layout: [num_pages, PAGE_SIZE, N_KV_HEADS, D_HEAD]
1002
+ # A program is responsible for 1 batch, 1 head and a block of sequences.
1003
+ batch_id = tl.program_id(axis=0)
1004
+ head_id = tl.program_id(axis=1)
1005
+ seq_block_id = tl.program_id(axis=2)
1006
+
1007
+ SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK_SIZE // PAGE_SIZE
1008
+ MAX_NUM_PAGES: tl.constexpr = MAX_SEQ_LEN // PAGE_SIZE
1009
+
1010
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
1011
+ seq_len = tl.load(cache_len_ptr + batch_id)
1012
+ # Offsets for the block of sequences this program processes.
1013
+ seq_start_pos = seq_block_id * SEQ_BLOCK_SIZE
1014
+
1015
+ if seq_start_pos > seq_len:
1016
+ return
1017
+ seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
1018
+ seq_mask = seq_offsets <= seq_len
1019
+ # Assuming D_HEAD is a power of 2
1020
+ dhead_offsets = tl.arange(0, D_HEAD)
1021
+ dhead_mask = dhead_offsets < D_HEAD
1022
+
1023
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
1024
+ cache_head_offset = (head_id // HEAD_RATIO) * D_HEAD
1025
+
1026
+ sm_scale: tl.constexpr = 1 / (D_HEAD**0.5)
1027
+
1028
+ # Program loads the entire Q for the head assigned to it.
1029
+ # [D_HEAD]
1030
+ q_batch_offset = batch_id * N_HEADS * D_HEAD
1031
+ q_head_offset = head_id * D_HEAD
1032
+ q = tl.load(q_ptr + q_batch_offset + q_head_offset + dhead_offsets)
1033
+
1034
+ kv_mask = seq_mask[:, None] * dhead_mask[None, :]
1035
+
1036
+ kv_pages = seq_block_id * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE)
1037
+ cache_pages = tl.load(
1038
+ page_table_ptr + cache_loc * PAGE_TABLE_STRIDE + kv_pages, mask=kv_pages < MAX_NUM_PAGES
1039
+ )
1040
+
1041
+ page_offsets = tl.arange(0, PAGE_SIZE)
1042
+ # shape [SEQ_BLOCK], means [cache_pages, page_offsets]
1043
+ # token offsets in the paged kv cache
1044
+ cache_seq_offset = tl.reshape(
1045
+ cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK_SIZE]
1046
+ )
1047
+
1048
+ cache_offsets = (
1049
+ cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD + cache_head_offset + dhead_offsets[None, :]
1050
+ )
1051
+
1052
+ k = tl.load(k_cache_ptr + cache_offsets, mask=kv_mask)
1053
+ v = tl.load(v_cache_ptr + cache_offsets, mask=kv_mask)
1054
+
1055
+ # Note: check the output precision of the sum.
1056
+ # compute q*K^T
1057
+ # [D_HEAD] * [seq_block, D_HEAD], sum along axis 1
1058
+ attn = tl.sum(q[None, :] * k, axis=1) # [seq_block]
1059
+ attn = attn.to(tl.float32)
1060
+ attn *= sm_scale
1061
+ max_attn = tl.max(attn)
1062
+ # Set attn values to -inf where the mask is not set. This forces exp(attn) to 0.
1063
+ attn = tl.where(seq_mask, attn, float("-inf"))
1064
+ exp_attn = tl.exp(attn - max_attn)
1065
+
1066
+ sumexp = tl.sum(exp_attn, axis=0) # scalar.
1067
+
1068
+ # [seq_len] * [seq_len, D_HEAD], sum along axis 0
1069
+ output = tl.sum(exp_attn[:, None] * v, axis=0) # [D_HEAD]
1070
+
1071
+ output = output / sumexp
1072
+
1073
+ # We store the log-sum-exp after removing the max.
1074
+ logsumexp = tl.log(sumexp) + max_attn
1075
+ # when seq_mask is all false, max_attn will be -inf and sumexp is zero
1076
+
1077
+ tl.store(
1078
+ output_values_ptr
1079
+ + batch_id * N_HEADS * D_HEAD * num_blocks
1080
+ + head_id * D_HEAD * num_blocks
1081
+ + seq_block_id * D_HEAD
1082
+ + dhead_offsets,
1083
+ output,
1084
+ )
1085
+ tl.store(
1086
+ output_logsumexp_ptr
1087
+ + batch_id * N_HEADS * num_blocks
1088
+ + head_id * num_blocks
1089
+ + seq_block_id,
1090
+ logsumexp,
1091
+ )
1092
+
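For generate-only batches, stage 1 above is typically paired with `attention_kv_stage2` to reduce the per-block partials. One plausible launch pattern is sketched below; the grid shapes follow the pointer comments in the kernels, while the block size and the use of the cached positions as `input_pos` are illustrative assumptions, not values taken from this repository.

```python
import torch
import triton

# Hedged two-stage launch sketch for paged flash decoding.
def paged_decode_attention(q, k_cache, v_cache, cache_loc, cache_len, page_table,
                           input_pos, max_seq_len, n_heads, n_kv_heads, d_head,
                           page_size, seq_block_size=256):
    batch = q.shape[0]
    num_blocks = triton.cdiv(max_seq_len, seq_block_size)
    partial = torch.empty(batch, n_heads, num_blocks, d_head, device=q.device, dtype=torch.float32)
    lse = torch.empty(batch, n_heads, num_blocks, device=q.device, dtype=torch.float32)
    attention_kv_paged_stage1[(batch, n_heads, num_blocks)](
        q, k_cache, v_cache, cache_loc, page_table, cache_len, partial, lse, num_blocks,
        MAX_SEQ_LEN=max_seq_len, N_HEADS=n_heads, N_KV_HEADS=n_kv_heads, D_HEAD=d_head,
        SEQ_BLOCK_SIZE=seq_block_size, PAGE_SIZE=page_size,
        PAGE_TABLE_STRIDE=page_table.stride(0),
    )
    out = torch.empty(batch, n_heads, d_head, device=q.device, dtype=q.dtype)
    attention_kv_stage2[(batch, n_heads)](
        partial, lse, out, input_pos,
        NUM_BLOCKS=num_blocks, N_HEADS=n_heads, D_HEAD=d_head, SEQ_BLOCK_SIZE=seq_block_size,
    )
    return out
```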
1093
+
1094
+ @triton.jit
1095
+ def context_attention_kv_paged(
1096
+ q_ptr, # [b*s,nd]
1097
+ seq_len_ptr, # [b] # length of each sequence in a batch
1098
+ seq_start_ptr, # [b] # start indices of a sequence in flattened q/k/v.
1099
+ k_cache_ptr, # [num_pages, page_size, n, d]
1100
+ v_cache_ptr, # [num_pages, page_size, n, d]
1101
+ cache_loc_ptr, # [b] # index of the sequence in the page table.
1102
+ cache_len_ptr, # [Batch] # Number of tokens in kv cache.
1103
+ page_table_ptr, # [b, max_num_pages_per_seq] # loc of the block page in the cache.
1104
+ softmax_scale,
1105
+ o_ptr,
1106
+ N_HEADS: tl.constexpr, # Number of heads
1107
+ N_KV_HEADS: tl.constexpr, # Number of KV heads.
1108
+ D_HEAD: tl.constexpr, # Dimension of each head.
1109
+ SEQ_BLOCK: tl.constexpr,
1110
+ MAX_SEQ_LENGTH: tl.constexpr,
1111
+ PAGE_SIZE: tl.constexpr,
1112
+ PAGE_TABLE_STRIDE: tl.constexpr,
1113
+ ):
1114
+ """Kernel for context phase.
1115
+
1116
+ Fuses rope
1117
+ Assuming:
1118
+ 1. Self-attention [seqlen(Q) == seqlen(K)]
1119
+ 2. Causal attention
1120
+ 3. QKV layout: [b*s,n,d]
1121
+ """
1122
+ batch_id = tl.program_id(axis=0)
1123
+ head_id = tl.program_id(axis=1)
1124
+ seq_block_id = tl.program_id(axis=2)
1125
+
1126
+ # Each program is responsible for a block of tokens in a single batch.
1127
+ seq_start_index = tl.load(seq_start_ptr + batch_id)
1128
+ seq_len = tl.load(seq_len_ptr + batch_id)
1129
+
1130
+ HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS
1131
+
1132
+ # assuming SEQ_BLOCK can be divided by PAGE_SIZE and PAGE_SIZE is a power of 2.
1133
+ SEQ_BLOCK_PAGE: tl.constexpr = SEQ_BLOCK // PAGE_SIZE
1134
+ MAX_NUM_PAGES: tl.constexpr = (MAX_SEQ_LENGTH + PAGE_SIZE - 1) // PAGE_SIZE
1135
+
1136
+ # cache is [num_pages, page_size, n, d]
1137
+ # cache_loc_ptr stores the batch index for the sequences provided to the kernel.
1138
+ cache_loc = tl.load(cache_loc_ptr + batch_id)
1139
+ table_batch_offset = cache_loc * PAGE_TABLE_STRIDE
1140
+
1141
+ # Assuming D_HEAD is a power of 2
1142
+ dhead_offsets = tl.arange(0, D_HEAD)
1143
+ dhead_mask = dhead_offsets < D_HEAD
1144
+
1145
+ seq_offsets = tl.arange(0, SEQ_BLOCK)
1146
+ q_seq_offsets = seq_block_id * SEQ_BLOCK + seq_offsets
1147
+ seq_mask = q_seq_offsets < seq_len
1148
+
1149
+ load_mask = seq_mask[:, None] * dhead_mask[None, :]
1150
+
1151
+ q_batch_offset = seq_start_index * N_HEADS * D_HEAD
1152
+ q_head_offset = head_id * D_HEAD
1153
+ cache_head_offset = (head_id // HEAD_RATIO) * D_HEAD
1154
+
1155
+ # Q will stay in SRAM
1156
+ q = tl.load(
1157
+ q_ptr
1158
+ + q_batch_offset
1159
+ + q_seq_offsets[:, None] * N_HEADS * D_HEAD
1160
+ + q_head_offset
1161
+ + dhead_offsets[None, :],
1162
+ mask=load_mask,
1163
+ )
1164
+ acc = tl.zeros([SEQ_BLOCK, D_HEAD], dtype=tl.float32)
1165
+ lse_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
1166
+ m_i = tl.zeros([SEQ_BLOCK], dtype=tl.float32) - float("inf")
1167
+
1168
+ cache_len = tl.load(cache_len_ptr + batch_id)
1169
+ total_len = cache_len + seq_len
1170
+ num_blocks = (total_len + SEQ_BLOCK - 1) // SEQ_BLOCK
1171
+ for s in range(0, num_blocks + 1, 1):
1172
+ kv_pages = s * SEQ_BLOCK_PAGE + tl.arange(0, SEQ_BLOCK_PAGE)
1173
+ cache_pages = tl.load(
1174
+ page_table_ptr + table_batch_offset + kv_pages, mask=kv_pages < MAX_NUM_PAGES
1175
+ )
1176
+
1177
+ page_offsets = tl.arange(0, PAGE_SIZE)
1178
+ # shape [SEQ_BLOCK], means [cache_pages, page_offsets]
1179
+ # physical token offsets in the paged kv cache
1180
+ cache_seq_offset = tl.reshape(
1181
+ cache_pages[:, None] * PAGE_SIZE + page_offsets[None, :], [SEQ_BLOCK]
1182
+ )
1183
+ cache_offsets = (
1184
+ cache_seq_offset[:, None] * N_KV_HEADS * D_HEAD
1185
+ + cache_head_offset
1186
+ + dhead_offsets[None, :]
1187
+ )
1188
+
1189
+ # logical kv tokens offsets
1190
+ kv_seq_offsets = s * SEQ_BLOCK + seq_offsets
1191
+ kv_seq_mask = kv_seq_offsets < total_len
1192
+ kv_load_mask = kv_seq_mask[:, None] * dhead_mask[None, :]
1193
+
1194
+ k = tl.load(k_cache_ptr + cache_offsets, mask=kv_load_mask)
1195
+ qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
1196
+ qk += tl.dot(q, k.trans())
1197
+ # causal mask, need to use kv_seq_offsets
1198
+ qk = tl.where(
1199
+ (q_seq_offsets[:, None] + cache_len) >= kv_seq_offsets[None, :], qk, float("-inf")
1200
+ )
1201
+
1202
+ qk *= softmax_scale
1203
+ # rowmax
1204
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
1205
+ p = tl.exp(qk - m_ij[:, None])
1206
+ v = tl.load(v_cache_ptr + cache_offsets, mask=kv_load_mask)
1207
+
1208
+ l_ij = tl.sum(p, 1)
1209
+ acc_scale = tl.exp(m_i - m_ij)
1210
+ acc = acc * acc_scale[:, None]
1211
+ p = p.to(v.dtype)
1212
+ acc += tl.dot(p, v)
1213
+ m_i = m_ij
1214
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
1215
+ lse_i = m_ij + tl.log(l_i_new)
1216
+
1217
+ o_scale = tl.exp(m_i - lse_i)
1218
+
1219
+ acc = acc * o_scale[:, None]
1220
+
1221
+ tl.store(
1222
+ o_ptr
1223
+ + q_batch_offset
1224
+ + q_seq_offsets[:, None] * N_HEADS * D_HEAD
1225
+ + q_head_offset
1226
+ + dhead_offsets[None, :],
1227
+ acc,
1228
+ mask=load_mask,
1229
+ )
1230
+
1231
+
1232
+
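+ # Illustrative reference only (not used by the kernel): the streaming-softmax update in the
+ # paged context-attention kernel above is numerically equivalent to the following eager
+ # PyTorch sketch for a single head and query block (causal masking omitted for brevity).
+ # Names mirror the kernel variables; this is a hedged sketch, not the production code path.
+ #
+ #   import torch
+ #
+ #   def streaming_softmax_reference(q, k_blocks, v_blocks, softmax_scale):
+ #       """q: [S, D]; k_blocks / v_blocks: lists of [B, D] tiles of the full K / V."""
+ #       acc = torch.zeros(q.shape[0], v_blocks[0].shape[-1])
+ #       lse = torch.full((q.shape[0],), float("-inf"))
+ #       m = torch.full((q.shape[0],), float("-inf"))
+ #       for k_blk, v_blk in zip(k_blocks, v_blocks):
+ #           qk = (q @ k_blk.T) * softmax_scale
+ #           m_new = torch.maximum(qk.max(dim=1).values, lse)
+ #           p = torch.exp(qk - m_new[:, None])
+ #           acc = acc * torch.exp(m - m_new)[:, None] + p @ v_blk
+ #           lse = m_new + torch.log(torch.exp(lse - m_new) + p.sum(dim=1))
+ #           m = m_new
+ #       return acc * torch.exp(m - lse)[:, None]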
1233
+ @dataclass
1234
+ class PositionalEmbeddingConfig:
1235
+ """A dataclass to hold positional embedding information."""
1236
+
1237
+ mode: Optional[Literal["rope"]] = None
1238
+ rope_theta: float = 10000.0
1239
+ rope_scale: float = 1.0
1240
+
1241
+ def __post_init__(self):
1242
+ assert self.mode in [None, "rope"], f"Invalid mode: {self.mode}."
1243
+ if self.mode == "rope":
1244
+ assert self.rope_theta > 0, f"Invalid rope theta: {self.rope_theta}."
1245
+
1246
+
1247
+ @dataclass
1248
+ class CacheConfig:
1249
+ """A dataclass to hold information how to configure the cache."""
1250
+
1251
+ dtype: Optional[torch.dtype] = None
1252
+
1253
+
1254
+ @dataclass
1255
+ class AttentionInfo:
1256
+ """Information about the attention op.
1257
+
1258
+ This is the dataclass collected by the kvcache transformation and passed in to the
1259
+ AttentionDescriptor methods to inform the attention op about the attention configuration.
1260
+ """
1261
+
1262
+ num_heads: int
1263
+ num_kv_heads: int
1264
+ head_dim: int # embedding size of each head
1265
+ dtype: torch.dtype
1266
+
1267
+ cache_config: CacheConfig
1268
+ pos_embd_config: PositionalEmbeddingConfig
1269
+ # rope_dim represents embedding size of decoupled q/k that carry rope information
1270
+ # when rope_dim != 0 the decoupled q/k tensor carrying rope information is the last part of the tensor [-rope_dim: ]
1271
+ rope_dim: Optional[int] = 0
1272
+
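+ # Hedged usage sketch (all values below are made-up illustrations): the kvcache
+ # transformation would populate an AttentionInfo for a GQA layer roughly as follows.
+ #
+ #   info = AttentionInfo(
+ #       num_heads=32,
+ #       num_kv_heads=8,
+ #       head_dim=128,
+ #       dtype=torch.float16,
+ #       cache_config=CacheConfig(dtype=torch.float16),
+ #       pos_embd_config=PositionalEmbeddingConfig(mode="rope", rope_theta=10000.0),
+ #   )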
1273
+
1274
+ @dataclass
1275
+ class SequenceInfo:
1276
+ """A dataclass to hold information about how the sequence is laid out and stored in cache.
1277
+
1278
+ We assume the sequence + cache is laid out in the following way:
1279
+
1280
+ - input_ids: [id_0, ..., id_{s_total-1}]
1281
+ flattened sequence of [b, 1] or [1, s_total]. We use [b, 1] to denote generate-only batches.
1282
+ - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
1283
+ Describes how long each sequence is. For example,
1284
+ input_ids[:s_0] will correspond to sequence 0 in the batch and input_ids[s_0:s_0 + s_1] will
1285
+ correspond to sequence 1 in the batch.
1286
+ - input_pos: [pos_0, ..., pos_{b-1}]
1287
+ Corresponds to the total number of tokens that have already been cached for each sequence
1288
+ in the batch.
1289
+ - cache_loc: [c0, ...., c_{np-1}] where np is total number of pages allocated to describe all
1290
+ sequences in the batch.
1291
+ - pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
1292
+ sequence i. Note that, for example, cache_loc[ps_0:ps_0 + ps_1] will correspond to the pages associated
1293
+ with sequence 1 in the batch.
1294
+
1295
+ Here are a couple of notes to emphasize this notation:
1296
+
1297
+ - The total number of allocated token space for sequence i is given by ps_i * page_size. This is
1298
+ the total number of tokens that can be cached for each sequence.
1299
+
1300
+ - NOTE: It must hold that pos_i + s_i <= ps_i * page_size for all i in [0, b-1]. Moreover, it is
1301
+ the responsibility of the cache manager and/or runtime to ensure sufficient page allocation
1302
+ for each sequence.
1303
+
1304
+ """
1305
+
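+ # Worked example of the layout described above (illustrative numbers only): a batch of
+ # b=2 sequences with 3 and 2 input tokens, page_size=4, and no cache history:
+ #
+ #   input_ids     = [i0, i1, i2, j0, j1]   # s_total = 5, stored as shape [1, 5]
+ #   seq_len       = [3, 2]
+ #   input_pos     = [0, 0]                 # nothing cached yet
+ #   pages_per_seq = [1, 1]                 # ceil(3 / 4) and ceil(2 / 4)
+ #   cache_loc     = [7, 2]                 # page ids handed out by the cache manager
+ #
+ # so sequence 0 lives in page 7, sequence 1 in page 2, and the invariant
+ # input_pos[i] + seq_len[i] <= pages_per_seq[i] * page_size holds for both sequences.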
1306
+ ## USE TO INITIALIZE DATA CLASS ###############################################################
1307
+ # max_seq_len corresponds to the maximum number of tokens in any sequence. It includes the tokens in the
1308
+ # input sequence and the tokens generated by the model.
1309
+ max_seq_len: int = 1
1310
+ # max_batch_size corresponds to the maximum number of sequences (or requests) that the model can process.
1311
+ max_batch_size: int = 1
1312
+ # page_size is the granularity with which the cache pages are allocated for a paged kv cache.
1313
+ # For an unpaged cache, the page size should be set to max_seq_len.
1314
+ # Also note that two sequences in a batch cannot share a page.
1315
+ page_size: int = 0
1316
+ # max_num_tokens is the maximum number of tokens that the model can process across all sequences in the batch.
1317
+ # If a batch is composed of context-only requests of input sequence length ISL,
1318
+ # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens // ISL).
1319
+ # Similarly, if a batch is composed of generate-only requests,
1320
+ # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens).
1321
+ max_num_tokens: int = 0
1322
+
1323
+ ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP #################
1324
+ # input_ids MUST ALWAYS BE THE FIRST FIELD
1325
+ input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int))
1326
+ seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
1327
+ input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int))
1328
+ cache_loc: torch.Tensor = field(default_factory=lambda: torch.arange(1, dtype=torch.int))
1329
+ pages_per_seq: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int))
1330
+ ################################################################################################
1331
+
1332
+ ## PRIVATE FIELDS ##############################################################################
1333
+ _sequence_lengths: List[int] = field(default_factory=list)
1334
+ _num_pages: int = 1
1335
+
1336
+ def __post_init__(self):
1337
+ if self.page_size < 1:
1338
+ self.page_size = self.max_seq_len
1339
+ if self.max_num_tokens < 1:
1340
+ self.max_num_tokens = self.max_batch_size * self.max_seq_len
1341
+ # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
1342
+ # we use the provided max_num_tokens to calculate the number of pages
1343
+ total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len)
1344
+ self._num_pages = (total_tokens) // self.page_size + (total_tokens % self.page_size > 0)
1345
+ self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
1346
+ self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
1347
+ self.input_pos = torch.empty_like(self.seq_len)
1348
+ self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
1349
+ self.pages_per_seq = torch.empty_like(self.seq_len)
1350
+
1351
+ # dynamic shape descriptors for tensor args
1352
+ self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None
1353
+
1354
+ # keep a list-like object of sequence lengths for simplicity as well
1355
+ self._sequence_lengths = [0] * self.max_batch_size
1356
+
1357
+ # call reset once to initialize the tensors
1358
+ self.reset()
1359
+
1360
+ @property
1361
+ def device(self) -> torch.device:
1362
+ return self.input_pos.device
1363
+
1364
+ @property
1365
+ def args(self) -> List[torch.Tensor]:
1366
+ args = []
1367
+ for f in fields(self):
1368
+ val = getattr(self, f.name)
1369
+ if isinstance(val, torch.Tensor):
1370
+ args.append(val)
1371
+ return args
1372
+
1373
+ @property
1374
+ def extra_arg_names(self) -> List[str]:
1375
+ """Return extra arg names for the prepare_metadata op beyond input_ids."""
1376
+ return [f.name for f in fields(self) if isinstance(getattr(self, f.name), torch.Tensor)][1:]
1377
+
1378
+ @property
1379
+ def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]:
1380
+ """Return dynamic shapes of sequence info tensors.
1381
+
1382
+ NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
1383
+ """
1384
+ if self._dynamic_shapes is None:
1385
+ dynamic_shapes = ({},)
1386
+ if self.max_batch_size > 1:
1387
+ dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size)
1388
+ dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len)
1389
+ dynamic_shapes += ({},) * len(self.extra_arg_names)
1390
+ self._dynamic_shapes = dynamic_shapes
1391
+ return self._dynamic_shapes
1392
+
1393
+ @property
1394
+ def num_sequences(self) -> int:
1395
+ return len(self._sequence_lengths)
1396
+
1397
+ @property
1398
+ def sequence_lengths(self) -> List[int]:
1399
+ return self._sequence_lengths
1400
+
1401
+ @property
1402
+ def input_positions(self) -> List[int]:
1403
+ return self.input_pos[: self.num_sequences].tolist()
1404
+
1405
+ @property
1406
+ def is_generate(self) -> bool:
1407
+ return all(sl == 1 for sl in self.sequence_lengths)
1408
+
1409
+ @property
1410
+ def num_pages(self) -> int:
1411
+ return self._num_pages
1412
+
1413
+ @num_pages.setter
1414
+ def num_pages(self, value):
1415
+ self._num_pages = value
1416
+ # update the cache_loc tensor
1417
+ self.cache_loc.resize_(value)
1418
+
1419
+ @property
1420
+ def is_paged(self) -> bool:
1421
+ return self.page_size < self.max_seq_len
1422
+
1423
+ @property
1424
+ def page_assignments(self) -> List[List[int]]:
1425
+ """Return the page assignments for each sequence."""
1426
+ pages_per_seq = self.pages_per_seq[: self.num_sequences].tolist()
1427
+ return [
1428
+ c_loc_one_seq.tolist()
1429
+ for c_loc_one_seq in torch.split(self.cache_loc[: sum(pages_per_seq)], pages_per_seq)
1430
+ ]
1431
+
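+ # Hedged example for page_assignments (illustrative values): with two active sequences,
+ # pages_per_seq[:2] == [1, 2], and cache_loc[:3] == [7, 2, 5], the property yields
+ #
+ #   >>> si.page_assignments
+ #   [[7], [2, 5]]
+ #
+ # i.e. the per-sequence view of the flat cache_loc tensor.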
1432
+ @classmethod
1433
+ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
1434
+ """Sanitize sequence lengths.
1435
+
1436
+ We want to cover the following scenarios with this function:
1437
+
1438
+ 1. Pre-fill:
1439
+ input_ids: [1, s_total, ...]
1440
+ seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
1441
+ ---> returns [s_0, s_1, ..., s_{b-1}]
1442
+ 2. Decode:
1443
+ input_ids: [b, 1, ...]
1444
+ seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
1445
+ |---- b ----|--- (max_batch_size - b) ---|
1446
+ --> returns [1,] * b
1447
+ 3. Decode in Cudagraph:
1448
+ input_ids: [b_cudagraph, 1, ...]
1449
+ seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
1450
+ |---- b ----|--- (max_batch_size - b) ---|
1451
+
1452
+ --> returns [1,] * b_cudagraph
1453
+ Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
1454
+ b_cudagraph.
1455
+
1456
+ # TODO: I could see one possible issue with this approach in the future.
1457
+ # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
1458
+ # information. What could happen is that for the padded sequences the cache location
1459
+ # tensors point to allocated pages. This could lead to a situation where we write into
1460
+ # allocated cache pages polluting the cache of other sequences. Now this is not an issue
1461
+ # if we write the dummy sequences into unallocated cache pages... One fix could be to
1462
+ # pad not only the seq len but also pad the cache locations by just repeating the last
1463
+ # valid cache location in the batch. This would ensure that the dummy sequences just
1464
+ # repeat valid computation...
1465
+ """
1466
+ _, s = input_ids.shape[:2]
1467
+ num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len)
1468
+ if s > 1:
1469
+ return seq_len[:num_seq].detach().clone()
1470
+ else:
1471
+ return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
1472
+
1473
+ @staticmethod
1474
+ def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
1475
+ """Get number of sequences.
1476
+
1477
+ We make sure that this function is compatible with both torch graph capture and cudagraph.
1478
+ Both can be a bit temperamental when trying to extract the number of sequences from a tensor
1479
+ with max_batch_size or max_batch_size*max_seq_len.
1480
+ """
1481
+ b, s = input_ids.shape[:2]
1482
+ if s > 1:
1483
+ num_seq = torch.sum(seq_len > 0)
1484
+ assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
1485
+ else:
1486
+ num_seq = b
1487
+ return num_seq
1488
+
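+ # Hedged examples for the two sanitizers above (shapes and values are illustrative):
+ #
+ #   # pre-fill: input_ids is [1, s_total] and seq_len is zero-padded to max_batch_size
+ #   >>> SequenceInfo._get_sanitized_seq_len(torch.zeros(1, 5), torch.tensor([3, 2, 0, 0]))
+ #   tensor([3, 2])
+ #
+ #   # decode (possibly cudagraph-padded): input_ids is [b, 1], so b alone determines num_seq
+ #   >>> SequenceInfo._get_sanitized_seq_len(torch.zeros(4, 1), torch.tensor([1, 1, 0, 0]))
+ #   tensor([1, 1, 1, 1])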
1489
+ def to(self, *args, **kwargs) -> None:
1490
+ for f in fields(self):
1491
+ val = getattr(self, f.name)
1492
+ if isinstance(val, torch.Tensor):
1493
+ setattr(self, f.name, val.to(*args, **kwargs))
1494
+
1495
+ def sync(self, other: "SequenceInfo") -> None:
1496
+ for f in fields(self):
1497
+ val = getattr(self, f.name)
1498
+ val_other = getattr(other, f.name)
1499
+ if f.name == "input_ids":
1500
+ setattr(self, f.name, val_other.to(self.device))
1501
+ elif f.name == "_sequence_lengths":
1502
+ self._sequence_lengths = val_other
1503
+ elif isinstance(val, torch.Tensor):
1504
+ val[: len(val_other)] = val_other.to(self.device)
1505
+ else:
1506
+ assert val == val_other, f"Field {f.name} mismatch: {val} != {val_other}."
1507
+
1508
+ def reset(self) -> None:
1509
+ """Reset the sequence information.
1510
+
1511
+ After reset the sequence information should correspond to a "generate-only" batch of
1512
+ sequences (b, s==1) without cache history.
1513
+ """
1514
+ # set a dummy sequence corresponding to a generate-only batch
1515
+ self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
1516
+
1517
+ # reset cache information
1518
+ self.input_pos.zero_()
1519
+ self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
1520
+ self.pages_per_seq.fill_(1)
1521
+
1522
+ def _set_example_sequence(self) -> None:
1523
+ """Set an example sequence for export purposes."""
1524
+ self.reset()
1525
+ input_ids = torch.ones(
1526
+ min(2, self.max_batch_size),
1527
+ min(4, self.max_seq_len),
1528
+ dtype=torch.int,
1529
+ device=self.device,
1530
+ )
1531
+ self.nest_sequences(input_ids)
1532
+ self.input_ids = input_ids
1533
+
1534
+ def _set_max_num_tokens_sample(self) -> None:
1535
+ """Set an example sequence with max_num_tokens."""
1536
+ self.reset()
1537
+ seq_len = self.max_num_tokens // self.max_batch_size
1538
+ input_ids = torch.ones(
1539
+ self.max_batch_size,
1540
+ seq_len,
1541
+ dtype=torch.int,
1542
+ device=self.device,
1543
+ )
1544
+ self.pages_per_seq.fill_(seq_len // self.page_size)
1545
+ self.nest_sequences(input_ids)
1546
+
1547
+ def _set_generate_only_batch(self) -> None:
1548
+ """Set an example sequence for generate-only batch."""
1549
+ self.reset()
1550
+ self.nest_sequences([[1]] * self.max_batch_size)
1551
+
1552
+ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
1553
+ """Create and store a flattened list of input_ids from the provided list of sequences.
1554
+
1555
+ This i/f will also update any relevant sequence information.
1556
+ """
1557
+ # set new sequence lengths
1558
+ seq_lens = [len(ids) for ids in input_ids]
1559
+ self.seq_len.zero_()
1560
+ self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True)
1561
+
1562
+ # set new input_ids as new tensor from flattened input_ids
1563
+ ids_tnsr_list = [
1564
+ lst.detach() if isinstance(lst, torch.Tensor) else torch.tensor(lst, dtype=torch.int)
1565
+ for lst in input_ids
1566
+ ]
1567
+ self.input_ids = torch.cat(ids_tnsr_list, dim=0).to(self.device)
1568
+
1569
+ # set derivative properties
1570
+ self._sequence_lengths = seq_lens
1571
+
1572
+ # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
1573
+ if self.is_generate:
1574
+ self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:])
1575
+ else:
1576
+ self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:])
1577
+
1578
+ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
1579
+ t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
1580
+ return list(torch.split(t_squeezed, self.sequence_lengths))
1581
+
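+ # Hedged usage sketch for nest_sequences/unnest_sequences (illustrative values, assuming a
+ # SequenceInfo `si` with max_batch_size >= 2 and max_seq_len >= 3):
+ #
+ #   >>> si.nest_sequences([[11, 12, 13], [21, 22]])   # context batch of 2 sequences
+ #   >>> si.input_ids.shape                            # flattened to [1, s_total]
+ #   torch.Size([1, 5])
+ #   >>> si.sequence_lengths
+ #   [3, 2]
+ #   >>> [t.tolist() for t in si.unnest_sequences(si.input_ids)]
+ #   [[11, 12, 13], [21, 22]]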
1582
+ def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
1583
+ """Update the starting position for each sequence in the cache.
1584
+
1585
+ If ``reset=True``, ``input_pos`` will be reset to zero before updating.
1586
+ """
1587
+ if not isinstance(seq_len, torch.Tensor):
1588
+ seq_len = torch.tensor(seq_len, dtype=torch.int)
1589
+ bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size
1590
+
1591
+ if reset:
1592
+ self.input_pos[:bs] = seq_len.to(self.device)
1593
+ else:
1594
+ self.input_pos[:bs] += seq_len.to(self.device)
1595
+
1596
+ def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
1597
+ """Set the cache location and pages_per_seq tensors from page assignments."""
1598
+ cache_loc_flat = torch.tensor(
1599
+ [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
1600
+ )
1601
+ self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
1602
+
1603
+ pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int)
1604
+ self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True)
1605
+
1606
+
1607
+ Constant = Union[int, float, str, None]
1608
+
1609
+
1610
+ class MHACallable(Protocol):
1611
+ def __call__(
1612
+ self,
1613
+ *qkv_metadata_and_caches: Union[torch.Tensor, Constant],
1614
+ ) -> torch.Tensor: ...
1615
+
1616
+
1617
+ class PrepareMetadataCallable(Protocol):
1618
+ def __call__(
1619
+ self,
1620
+ input_ids: torch.Tensor,
1621
+ seq_len: torch.Tensor,
1622
+ input_pos: torch.Tensor,
1623
+ cache_loc: torch.Tensor,
1624
+ pages_per_seq: torch.Tensor,
1625
+ page_size: int,
1626
+ ) -> List[torch.Tensor]: ...
1627
+
1628
+
1629
+ class GetCacheCallable(Protocol):
1630
+ def __call__(self, sequence_info: SequenceInfo) -> torch.Tensor: ...
1631
+
1632
+
1633
+ class GetBufferCallable(GetCacheCallable):
1634
+ pass
1635
+
1636
+
1637
+ class GetAttentionInfo(Protocol):
1638
+ def __call__(self) -> AttentionInfo: ...
1639
+
1640
+
1641
+ CacheInitializerDict = Dict[str, GetCacheCallable]
1642
+ BufferInitializerDict = Dict[str, GetBufferCallable]
1643
+
1644
+
1645
+ class AttentionDescriptor(ABC):
1646
+ """An interface to define a functional attention operator.
1647
+
1648
+ The main logic is contained within the actual attention op as well as the prepare_metadata op. The
1649
+ prepare_metadata op is responsible for converting the standardized sequence info into metadata
1650
+ specific to the attention op.
1651
+ """
1652
+
1653
+ @classmethod
1654
+ @abstractmethod
1655
+ def is_paged(cls) -> bool:
1656
+ """Return if the attention op is paged or not."""
1657
+
1658
+ @classmethod
1659
+ def get_attention_op(cls) -> Tuple[MHACallable, int]:
1660
+ """Get the attention op and the number of arguments corresponding to qkv.
1661
+
1662
+ The attention_op should follow the below signature:
1663
+
1664
+ ```
1665
+ def attention_op(
1666
+ *qkv, # list of tensors corresponding to Q, K, V as in original op
1667
+ *metadata, # global info about the sequences as returned by the prepare_metadata op
1668
+ *caches, # contains layer-specific caches per provided cache initializers
1669
+ *buffers, # global buffers used by the attention op as provided by buffer initializers
1670
+ *constants, # basic arguments (int, float, str, None) added as CONSTANTS in the graph
1671
+ ) -> torch.Tensor: ...
1672
+ ```
1673
+
1674
+ **Note that the attention op should be a valid torch custom op, which comes with
1675
+ restrictions on the supported types in the signature.**
1676
+
1677
+ **Note that the `qkv` tuple should be consistent across both the cached attention
1678
+ op and the op that it is replacing.**
1679
+
1680
+ """
1681
+ raise NotImplementedError
1682
+
1683
+ @classmethod
1684
+ @abstractmethod
1685
+ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:
1686
+ """Get the prepare_metadata op.
1687
+
1688
+ The prepare_metadata op should follow the below signature:
1689
+
1690
+ ```
1691
+ def prepare_metadata(
1692
+ input_ids: torch.Tensor,
1693
+ seq_len: torch.Tensor,
1694
+ input_pos: torch.Tensor,
1695
+ cache_loc: torch.Tensor,
1696
+ ) -> List[torch.Tensor]: ...
1697
+ ```
1698
+ The metadata should contain all necessary global information required for the underlying
1699
+ attention op to process the input sequence and the returned list of tensors will be passed
1700
+ on to each invocation of the attention op in the graph.
1701
+
1702
+ prepare_metadata is called once at the beginning of the forward pass.
1703
+
1704
+ **Note that the prepare_metadata op should be a valid torch custom op, which comes with
1705
+ restrictions on the supported types in the signature.**
1706
+ """
1707
+ raise NotImplementedError
1708
+
1709
+ @classmethod
1710
+ @abstractmethod
1711
+ def get_cache_initializers(cls, get_info: GetAttentionInfo) -> CacheInitializerDict:
1712
+ """Provide a dictionary of function pointers that can be used to initialize the caches.
1713
+
1714
+ The key corresponds to the argument name used in the attention op signature. The function
1715
+ key doesn't need to be unique across multiple attention nodes in the graph. The key used to
1716
+ describe the cache in the graph will be patched with the attention node index to ensure
1717
+ uniqueness.
1718
+
1719
+ ``get_cache_initializers`` will be called *once* after the model initialization and before
1720
+ the initial forward pass for each attention op detected in the graph. The caches will be
1721
+ managed by the global CacheManager and passed back to the attention op during the forward
1722
+ pass.
1723
+
1724
+ If the cache initializer requires information about the attention op, the ``get_info``
1725
+ function can be called **inside** the cache initializer to retrieve the necessary
1726
+ information.
1727
+ """
1728
+ raise NotImplementedError
1729
+
1730
+ @classmethod
1731
+ def get_global_buffer_initializers(cls, get_info: GetAttentionInfo) -> BufferInitializerDict:
1732
+ """Provide a dictionary of function pointers that can be used to initialize buffers.
1733
+
1734
+ The key corresponds to the buffer name used in the graph module and will **not**
1735
+ be patched unlike a cache key. Hence, it is a **global** key that is shared across all
1736
+ attention ops in the model much like a regular buffer in an nn.Module. That means if this
1737
+ i/f is called for multiple attention ops, the same buffer will be shared across all of them
1738
+ if this function provides the same key multiple times.
1739
+
1740
+ Buffers are initialized *once* after the model initialization and before the initial forward
1741
+ pass for each attention op detected in the graph. The buffer will be managed by the global
1742
+ CacheManager and passed back to the attention op during the forward pass.
1743
+
1744
+ If the buffer initializer requires information about the attention op, the ``get_info``
1745
+ function can be called **inside** the buffer initializer to retrieve the necessary
1746
+ information.
1747
+ """
1748
+ return {}
1749
+
1750
+ @classmethod
1751
+ def get_constants(cls, attention_info: AttentionInfo) -> List[Constant]:
1752
+ """Provide a list of constant arguments to be passed to the attention op.
1753
+
1754
+ The constant arguments are passed to the attention op as additional arguments after the
1755
+ caches and buffers. The constants are expected to be of type int, float, str, or None.
1756
+ """
1757
+ return []
1758
+
1759
+
1760
+ class AttentionRegistry:
1761
+ """A simple registry to look up different attention implementations."""
1762
+
1763
+ _attention_registry: Dict[str, Type["AttentionDescriptor"]] = {}
1764
+
1765
+ @classmethod
1766
+ def register(cls, kernel_source: str) -> Type["AttentionDescriptor"]:
1767
+ def decorator(attention_cls: Type["AttentionDescriptor"]):
1768
+ assert kernel_source not in cls._attention_registry, (
1769
+ f"Attention source {kernel_source} already registered."
1770
+ )
1771
+ cls._attention_registry[kernel_source] = attention_cls
1772
+ return attention_cls
1773
+
1774
+ return decorator
1775
+
1776
+ @classmethod
1777
+ def get(cls, kernel_source: str) -> Type["AttentionDescriptor"]:
1778
+ assert cls.has(kernel_source), f"Attention source {kernel_source} not registered."
1779
+ return cls._attention_registry[kernel_source]
1780
+
1781
+ @classmethod
1782
+ def has(cls, kernel_source: str) -> bool:
1783
+ return kernel_source in cls._attention_registry
1784
+
1785
+
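+ # Hedged usage sketch of the registry (the kernel-source name "my_triton_mha" and the class
+ # are made up for illustration; real implementations must fill in all abstract methods):
+ #
+ #   @AttentionRegistry.register("my_triton_mha")
+ #   class MyTritonMHA(AttentionDescriptor):
+ #       @classmethod
+ #       def is_paged(cls):
+ #           return False
+ #       ...  # get_attention_op, get_prepare_metadata_op, get_cache_initializers, etc.
+ #
+ #   descriptor = AttentionRegistry.get("my_triton_mha")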
1786
+
1787
+ @torch.library.custom_op("attention::scaled_dot_product_attention", mutates_args=())
1788
+ def scaled_dot_product_attention(
1789
+ query: torch.Tensor,
1790
+ key: torch.Tensor,
1791
+ value: torch.Tensor,
1792
+ attn_mask: Optional[torch.Tensor] = None,
1793
+ dropout_p: float = 0.0,
1794
+ is_causal: bool = False,
1795
+ scale: Optional[float] = None,
1796
+ ) -> torch.Tensor:
1797
+ """A carbon copy of torch.nn.functional.scaled_dot_product_attention as custom op.
1798
+
1799
+ Using this custom op instead of using the functional directly ensures consistent representation
1800
+ of the vanilla sdpa in a graph.
1801
+ """
1802
+ return F.scaled_dot_product_attention(
1803
+ query,
1804
+ key,
1805
+ value,
1806
+ attn_mask=attn_mask,
1807
+ dropout_p=dropout_p,
1808
+ is_causal=is_causal,
1809
+ scale=scale,
1810
+ )
1811
+
1812
+
1813
+ @scaled_dot_product_attention.register_fake
1814
+ def scaled_dot_product_attention_fake(
1815
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
1816
+ ):
1817
+ """Fake implementation of scaled_dot_product_attention."""
1818
+ return torch.empty_like(query)
1819
+
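+ # Hedged usage note: the custom op is called exactly like the functional it wraps, e.g.
+ #
+ #   y = torch.ops.attention.scaled_dot_product_attention(q, k, v, is_causal=True)
+ #
+ # which keeps a single canonical sdpa node in traced/exported graphs.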
1820
+
1821
+ def _generate_mha(
1822
+ q: torch.Tensor,
1823
+ k: torch.Tensor,
1824
+ v: torch.Tensor,
1825
+ k_cache: torch.Tensor,
1826
+ v_cache: torch.Tensor,
1827
+ cache_locs: torch.Tensor,
1828
+ input_pos: torch.Tensor,
1829
+ out: torch.Tensor,
1830
+ ):
1831
+ b, (n_heads, q_d_head) = q.shape[0], q.shape[-2:]
1832
+ max_seq_len, n_kv_heads = k_cache.shape[1:3]
1833
+ v_d_head = v.shape[-1]
1834
+ device = q.device
1835
+
1836
+ HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
1837
+ SEQ_BLOCK_SIZE = 256
1838
+ num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
1839
+
1840
+ stage1_output_values = torch.empty(
1841
+ b, n_heads, num_blocks, v_d_head, device=device, dtype=torch.float32
1842
+ )
1843
+ stage1_output_logsumexp = torch.empty(
1844
+ b, n_heads, num_blocks, device=device, dtype=torch.float32
1845
+ ) - float("inf")
1846
+
1847
+ (
1848
+ update_kv_cache[(b, n_kv_heads, 1)](
1849
+ k,
1850
+ v,
1851
+ None,
1852
+ None,
1853
+ k_cache,
1854
+ v_cache,
1855
+ input_pos,
1856
+ cache_locs,
1857
+ max_seq_len,
1858
+ n_kv_heads,
1859
+ q_d_head,
1860
+ v_d_head,
1861
+ 1,
1862
+ GENERATE_ONLY=True,
1863
+ ),
1864
+ )
1865
+
1866
+ gqa_attention_kv_stage1[
1867
+ (
1868
+ b,
1869
+ n_kv_heads,
1870
+ num_blocks,
1871
+ )
1872
+ ](
1873
+ q,
1874
+ k_cache,
1875
+ v_cache,
1876
+ cache_locs,
1877
+ input_pos,
1878
+ stage1_output_values,
1879
+ stage1_output_logsumexp,
1880
+ num_blocks,
1881
+ max_seq_len,
1882
+ n_heads,
1883
+ n_kv_heads,
1884
+ q_d_head,
1885
+ v_d_head,
1886
+ SEQ_BLOCK_SIZE,
1887
+ HEAD_BLOCK_SIZE,
1888
+ )
1889
+ attention_kv_stage2[(b, n_heads, 1)](
1890
+ stage1_output_values,
1891
+ stage1_output_logsumexp,
1892
+ out,
1893
+ input_pos,
1894
+ num_blocks,
1895
+ n_heads,
1896
+ v_d_head,
1897
+ SEQ_BLOCK_SIZE,
1898
+ )
1899
+
1900
+
1901
+ def _context_mha(
1902
+ q: torch.Tensor,
1903
+ k: torch.Tensor,
1904
+ v: torch.Tensor,
1905
+ k_cache: torch.Tensor,
1906
+ v_cache: torch.Tensor,
1907
+ out: torch.Tensor,
1908
+ ):
1909
+ b, s, n_heads, q_d_head = q.shape
1910
+ max_seq_len, n_kv_heads = k_cache.shape[1:3]
1911
+ v_d_head = v.shape[-1]
1912
+
1913
+ SEQ_BLOCK = 128
1914
+ softmax_scale = 1.0 / math.sqrt(q_d_head)
1915
+ grid = (b, n_heads, (s + SEQ_BLOCK - 1) // SEQ_BLOCK)
1916
+ context_attention_kv[grid](
1917
+ q,
1918
+ k,
1919
+ v,
1920
+ k_cache,
1921
+ v_cache,
1922
+ s,
1923
+ out,
1924
+ softmax_scale,
1925
+ n_heads,
1926
+ n_kv_heads,
1927
+ q_d_head,
1928
+ v_d_head,
1929
+ SEQ_BLOCK,
1930
+ max_seq_len,
1931
+ num_stages=2,
1932
+ )
1933
+
1934
+
1935
+ @torch.library.custom_op("attention::fused_mha_with_cache", mutates_args=())
1936
+ def fused_mha_with_cache(
1937
+ q: torch.Tensor,
1938
+ k: torch.Tensor,
1939
+ v: torch.Tensor,
1940
+ input_pos: torch.Tensor,
1941
+ k_cache: torch.Tensor,
1942
+ v_cache: torch.Tensor,
1943
+ freqs_cis: Optional[torch.Tensor],
1944
+ ) -> torch.Tensor:
1945
+ """Fused MHA with cache that takes raw input from q, k, v GEMMs."""
1946
+ # b, s info
1947
+ b, s = q.shape[:2]
1948
+ head_dim = k_cache.shape[-1]
1949
+
1950
+ # reshapes with num_heads and head_dim
1951
+ q = q.view(b, s, -1, head_dim)
1952
+ k = k.view(b, s, -1, head_dim)
1953
+ v = v.view(b, s, -1, head_dim)
1954
+
1955
+ # rope embedding
1956
+ if freqs_cis is not None:
1957
+ q = torch.ops.rope.apply_rope_with_input_pos(q, freqs_cis, input_pos, "bsnd")
1958
+ k = torch.ops.rope.apply_rope_with_input_pos(k, freqs_cis, input_pos, "bsnd")
1959
+
1960
+ # attention (assumed layout is bsnd)
1961
+ y = torch.empty_like(q)
1962
+ if s > 1:
1963
+ # context phase
1964
+ _context_mha(q, k, v, k_cache, v_cache, y)
1965
+ else:
1966
+ # generate phase
1967
+ cache_locs = torch.arange(0, b, device=q.device, dtype=torch.int32)
1968
+ _generate_mha(q, k, v, k_cache, v_cache, cache_locs, input_pos, y)
1969
+
1970
+ return y.view(b, s, -1) # [b,s,n*h_d]
1971
+
1972
+
1973
+ @fused_mha_with_cache.register_fake
1974
+ def fused_mha_fake(
1975
+ q: torch.Tensor,
1976
+ k: torch.Tensor,
1977
+ v: torch.Tensor,
1978
+ input_pos: torch.Tensor,
1979
+ k_cache: torch.Tensor,
1980
+ v_cache: torch.Tensor,
1981
+ freqs_cis: torch.Tensor,
1982
+ ):
1983
+ return torch.empty_like(q.contiguous())
1984
+
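+ # Hedged shape sketch for fused_mha_with_cache above (layout assumptions inferred from the
+ # code, not guaranteed):
+ #
+ #   q, k, v   : [b, s, n_(kv_)heads * head_dim]   raw outputs of the q/k/v GEMMs
+ #   k_cache   : [max_batch_size, max_seq_len, n_kv_heads, head_dim]   (v_cache likewise)
+ #   input_pos : [b]   number of tokens already cached per sequence
+ #   freqs_cis : packed cos/sin table, or None to skip rope
+ #
+ # s > 1 dispatches to the context kernel, s == 1 to the two-stage generate kernels.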
1985
+
1986
+ def _flattened_context_mha(
1987
+ q: torch.Tensor,
1988
+ k: torch.Tensor,
1989
+ v: torch.Tensor,
1990
+ input_pos: torch.Tensor,
1991
+ cache_loc: torch.Tensor,
1992
+ k_cache: torch.Tensor,
1993
+ v_cache: torch.Tensor,
1994
+ seq_len: torch.Tensor,
1995
+ seq_start: torch.Tensor,
1996
+ out: torch.Tensor,
1997
+ ) -> None:
1998
+ # NOTE: s_total == sum(seq_len)
1999
+ s_total, n_heads, q_d_head = q.shape
2000
+ max_cache_seq_len, n_kv_heads = k_cache.shape[1:3]
2001
+ v_d_head = v.shape[-1]
2002
+ BATCH_SIZE: int = len(input_pos)
2003
+ SEQ_BLOCK = 32
2004
+ (
2005
+ update_kv_cache[(BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)](
2006
+ k,
2007
+ v,
2008
+ seq_len,
2009
+ seq_start,
2010
+ k_cache,
2011
+ v_cache,
2012
+ input_pos,
2013
+ cache_loc,
2014
+ max_cache_seq_len,
2015
+ n_kv_heads,
2016
+ q_d_head,
2017
+ v_d_head,
2018
+ 32,
2019
+ GENERATE_ONLY=False,
2020
+ ),
2021
+ )
2022
+ # TODO: use input_pos to get the correct cache locations
2023
+ softmax_scale = 1.0 / math.sqrt(q_d_head)
2024
+ grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
2025
+ context_attention_kv_flattened[grid](
2026
+ q,
2027
+ seq_len,
2028
+ seq_start,
2029
+ k_cache,
2030
+ v_cache,
2031
+ input_pos,
2032
+ cache_loc,
2033
+ out,
2034
+ softmax_scale,
2035
+ n_heads,
2036
+ n_kv_heads,
2037
+ q_d_head,
2038
+ v_d_head,
2039
+ SEQ_BLOCK,
2040
+ max_cache_seq_len,
2041
+ num_stages=2,
2042
+ )
2043
+
2044
+
2045
+ @torch.library.custom_op("attention::fused_flattened_mha_with_cache", mutates_args=())
2046
+ def fused_flattened_mha_with_cache(
2047
+ # Q, K, V
2048
+ q: torch.Tensor,
2049
+ k: torch.Tensor,
2050
+ v: torch.Tensor,
2051
+ # METADATA
2052
+ seq_len: torch.Tensor,
2053
+ input_pos: torch.Tensor,
2054
+ cache_loc: torch.Tensor,
2055
+ seq_start: torch.Tensor,
2056
+ # CACHES
2057
+ k_cache: torch.Tensor,
2058
+ v_cache: torch.Tensor,
2059
+ # BUFFERS
2060
+ freqs_cis: torch.Tensor,
2061
+ # CONSTANTS
2062
+ # <none>
2063
+ ) -> torch.Tensor:
2064
+ """Flattened & fused MHA with cache that takes raw input from q, k, v GEMMs.
2065
+
2066
+ NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
2067
+ """
2068
+ # b, s info
2069
+ # NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
2070
+ # Generally speaking, we expect one of two cases here:
2071
+ # 1. b > 0, s==1: this indicates a generate-only batch of tokens.
2072
+ # 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
2073
+ # and number of tokens per sequence are encoded in seq_len and seq_start.
2074
+ head_dim = k_cache.shape[-1]
2075
+ b, s, d = q.shape
2076
+
2077
+ # reshapes with num_heads and head_dim
2078
+ if s == 1:
2079
+ bs_view = (b, s)
2080
+ else:
2081
+ bs_view = (b * s,)
2082
+ q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
2083
+ k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
2084
+ v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)
2085
+
2086
+ # rope embedding for generate-only or mixed
2087
+ if freqs_cis is not None and freqs_cis.numel() > 0:
2088
+ if s == 1:
2089
+ rope_args = (freqs_cis, input_pos, "bsnd")
2090
+ fn_rope = torch.ops.rope.apply_rope_with_input_pos
2091
+ else:
2092
+ rope_args = (freqs_cis, input_pos, seq_len, seq_start)
2093
+ fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
2094
+ q = fn_rope(q, *rope_args)
2095
+ k = fn_rope(k, *rope_args)
2096
+
2097
+ # run attention
2098
+ y = torch.empty_like(q)
2099
+ if s == 1:
2100
+ # generate-only phase
2101
+ _generate_mha(q, k, v, k_cache, v_cache, cache_loc, input_pos, y)
2102
+ else:
2103
+ # mixed context + generate phase
2104
+ _flattened_context_mha(
2105
+ q,
2106
+ k,
2107
+ v,
2108
+ input_pos,
2109
+ cache_loc,
2110
+ k_cache,
2111
+ v_cache,
2112
+ seq_len,
2113
+ seq_start,
2114
+ y,
2115
+ )
2116
+
2117
+ return y.view(b, s, d) # [b,s,n*h_d]
2118
+
2119
+
2120
+ @fused_flattened_mha_with_cache.register_fake
2121
+ def fused_flattened_mha_fake(
2122
+ q: torch.Tensor,
2123
+ k: torch.Tensor,
2124
+ v: torch.Tensor,
2125
+ seq_len: torch.Tensor,
2126
+ input_pos: torch.Tensor,
2127
+ cache_loc: torch.Tensor,
2128
+ seq_start: torch.Tensor,
2129
+ k_cache: torch.Tensor,
2130
+ v_cache: torch.Tensor,
2131
+ freqs_cis: torch.Tensor,
2132
+ ):
2133
+ return torch.empty_like(q.contiguous())
2134
+
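+ # Hedged call sketch for fused_flattened_mha_with_cache above (tensor contents illustrative):
+ #
+ #   # generate-only batch: q/k/v arrive as [b, 1, n*h_d]
+ #   y = torch.ops.attention.fused_flattened_mha_with_cache(
+ #       q, k, v, seq_len, input_pos, cache_loc, seq_start, k_cache, v_cache, freqs_cis
+ #   )
+ #
+ #   # mixed context batch: q/k/v arrive as [1, s_total, n*h_d] and seq_len/seq_start
+ #   # describe how s_total splits into the individual sequences.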
2135
+
2136
+ def _generate_mha_rope_fusion(
2137
+ q: torch.Tensor,
2138
+ k: torch.Tensor,
2139
+ v: torch.Tensor,
2140
+ freqs_cis: torch.Tensor,
2141
+ k_cache: torch.Tensor,
2142
+ v_cache: torch.Tensor,
2143
+ cache_locs: torch.Tensor,
2144
+ input_pos: torch.Tensor,
2145
+ out: torch.Tensor,
2146
+ ):
2147
+ b, (n_heads, d_head) = q.shape[0], q.shape[-2:]
2148
+ max_seq_len, n_kv_heads = k_cache.shape[1:3]
2149
+ device = q.device
2150
+
2151
+ SEQ_BLOCK_SIZE = 64
2152
+ num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
2153
+ stage1_output_values = torch.empty(
2154
+ b, n_heads, num_blocks, d_head, device=device, dtype=torch.float32
2155
+ )
2156
+ stage1_output_logsumexp = torch.empty(
2157
+ b, n_heads, num_blocks, device=device, dtype=torch.float32
2158
+ ) - float("inf")
2159
+ q_rope = torch.empty_like(q)
2160
+ HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
2161
+
2162
+ (
2163
+ update_kv_cache_rope_fusion[(b, n_kv_heads, 1)](
2164
+ q,
2165
+ k,
2166
+ v,
2167
+ None,
2168
+ None,
2169
+ q_rope,
2170
+ k_cache,
2171
+ v_cache,
2172
+ input_pos,
2173
+ cache_locs,
2174
+ freqs_cis,
2175
+ max_seq_len,
2176
+ n_heads,
2177
+ n_kv_heads,
2178
+ d_head,
2179
+ 1,
2180
+ HEAD_BLOCK_SIZE,
2181
+ GENERATE_ONLY=True,
2182
+ ),
2183
+ )
2184
+
2185
+ HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
2186
+ gqa_attention_kv_stage1[
2187
+ (
2188
+ b,
2189
+ n_kv_heads,
2190
+ num_blocks,
2191
+ )
2192
+ ](
2193
+ q_rope,
2194
+ k_cache,
2195
+ v_cache,
2196
+ cache_locs,
2197
+ input_pos,
2198
+ stage1_output_values,
2199
+ stage1_output_logsumexp,
2200
+ num_blocks,
2201
+ max_seq_len,
2202
+ n_heads,
2203
+ n_kv_heads,
2204
+ d_head,
2205
+ d_head,
2206
+ SEQ_BLOCK_SIZE,
2207
+ HEAD_BLOCK_SIZE,
2208
+ )
2209
+ attention_kv_stage2[(b, n_heads, 1)](
2210
+ stage1_output_values,
2211
+ stage1_output_logsumexp,
2212
+ out,
2213
+ input_pos,
2214
+ num_blocks,
2215
+ n_heads,
2216
+ d_head,
2217
+ SEQ_BLOCK_SIZE,
2218
+ )
2219
+
2220
+
2221
+ def _flattened_context_mha_rope_fusion(
2222
+ q: torch.Tensor,
2223
+ k: torch.Tensor,
2224
+ v: torch.Tensor,
2225
+ freqs_cis: torch.Tensor,
2226
+ input_pos: torch.Tensor,
2227
+ cache_loc: torch.Tensor,
2228
+ k_cache: torch.Tensor,
2229
+ v_cache: torch.Tensor,
2230
+ seq_len: torch.Tensor,
2231
+ seq_start: torch.Tensor,
2232
+ out: torch.Tensor,
2233
+ ) -> None:
2234
+ # NOTE: s_total == sum(seq_len)
2235
+ s_total, n_heads, d_head = q.shape
2236
+ max_cache_seq_len, n_kv_heads = k_cache.shape[1:3]
2237
+ BATCH_SIZE: int = len(input_pos)
2238
+ SEQ_BLOCK = 32
2239
+ q_rope = torch.empty_like(q)
2240
+ HEAD_BLOCK_SIZE = max(16, triton.next_power_of_2(n_heads // n_kv_heads))
2241
+ (
2242
+ update_kv_cache_rope_fusion[
2243
+ (BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
2244
+ ](
2245
+ q,
2246
+ k,
2247
+ v,
2248
+ seq_len,
2249
+ seq_start,
2250
+ q_rope,
2251
+ k_cache,
2252
+ v_cache,
2253
+ input_pos,
2254
+ cache_loc,
2255
+ freqs_cis,
2256
+ max_cache_seq_len,
2257
+ n_heads,
2258
+ n_kv_heads,
2259
+ d_head,
2260
+ 32,
2261
+ HEAD_BLOCK_SIZE,
2262
+ GENERATE_ONLY=False,
2263
+ ),
2264
+ )
2265
+ # TODO: use input_pos to get the correct cache locations
2266
+ softmax_scale = 1.0 / math.sqrt(d_head)
2267
+ grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
2268
+ context_attention_kv_flattened[grid](
2269
+ q_rope,
2270
+ seq_len,
2271
+ seq_start,
2272
+ k_cache,
2273
+ v_cache,
2274
+ input_pos,
2275
+ cache_loc,
2276
+ out,
2277
+ softmax_scale,
2278
+ n_heads,
2279
+ n_kv_heads,
2280
+ d_head,
2281
+ d_head,
2282
+ SEQ_BLOCK,
2283
+ max_cache_seq_len,
2284
+ num_stages=2,
2285
+ )
2286
+
2287
+
2288
+ @torch.library.custom_op("attention::fused_flattened_mha_with_cache_rope_fusion", mutates_args=())
2289
+ def fused_flattened_mha_with_cache_rope_fusion(
2290
+ q: torch.Tensor,
2291
+ k: torch.Tensor,
2292
+ v: torch.Tensor,
2293
+ input_pos: torch.Tensor,
2294
+ cache_loc: torch.Tensor,
2295
+ seq_len: torch.Tensor,
2296
+ seq_start: torch.Tensor,
2297
+ k_cache: torch.Tensor,
2298
+ v_cache: torch.Tensor,
2299
+ freqs_cis: Optional[torch.Tensor],
2300
+ ) -> torch.Tensor:
2301
+ """Flattened & fused MHA with cache that takes raw input from q, k, v GEMMs.
2302
+
2303
+ Fuses k rope in update_kv_cache and q rope in attention.
2304
+ NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
2305
+ """
2306
+ # this function only handles requests with rope embedding.
2307
+ if freqs_cis is None:
2308
+ return fused_flattened_mha_with_cache(
2309
+ q,
2310
+ k,
2311
+ v,
2312
+ seq_len,
2313
+ input_pos,
2314
+ cache_loc,
2315
+ seq_start,
2316
+ k_cache,
2317
+ v_cache,
2318
+ freqs_cis,
2319
+ )
2320
+
2321
+ # b, s info
2322
+ # NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
2323
+ # Generally speaking, we expect one of two cases here:
2324
+ # 1. b > 0, s==1: this indicates a generate-only batch of tokens.
2325
+ # 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
2326
+ # and number of tokens per sequence are encoded in seq_len and seq_start.
2327
+ b, s, d = q.shape
2328
+ head_dim = k_cache.shape[-1]
2329
+
2330
+ # reshapes with num_heads and head_dim
2331
+ if s == 1:
2332
+ bs_view = (b, s)
2333
+ else:
2334
+ bs_view = (b * s,)
2335
+ q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
2336
+ k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
2337
+ v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)
2338
+
2339
+ # run attention
2340
+ y = torch.empty_like(q)
2341
+ if s == 1:
2342
+ # generate-only phase
2343
+ _generate_mha_rope_fusion(q, k, v, freqs_cis, k_cache, v_cache, cache_loc, input_pos, y)
2344
+ else:
2345
+ # mixed context + generate phase
2346
+ _flattened_context_mha_rope_fusion(
2347
+ q,
2348
+ k,
2349
+ v,
2350
+ freqs_cis,
2351
+ input_pos,
2352
+ cache_loc,
2353
+ k_cache,
2354
+ v_cache,
2355
+ seq_len,
2356
+ seq_start,
2357
+ y,
2358
+ )
2359
+
2360
+ return y.view(b, s, d) # [b,s,n*h_d]
2361
+
2362
+
2363
+ @fused_flattened_mha_with_cache_rope_fusion.register_fake
2364
+ def fused_flattened_mha_with_cache_rope_fusion_fake(
2365
+ q: torch.Tensor,
2366
+ k: torch.Tensor,
2367
+ v: torch.Tensor,
2368
+ input_pos: torch.Tensor,
2369
+ cache_loc: torch.Tensor,
2370
+ seq_len: torch.Tensor,
2371
+ seq_start: torch.Tensor,
2372
+ k_cache: torch.Tensor,
2373
+ v_cache: torch.Tensor,
2374
+ freqs_cis: torch.Tensor,
2375
+ ):
2376
+ return torch.empty_like(q.contiguous())
2377
+
2378
+
2379
+ def _paged_generate_mha(
2380
+ q: torch.Tensor,
2381
+ k: torch.Tensor,
2382
+ v: torch.Tensor,
2383
+ page_table: torch.Tensor,
2384
+ k_cache: torch.Tensor,
2385
+ v_cache: torch.Tensor,
2386
+ cache_loc: torch.Tensor,
2387
+ input_pos: torch.Tensor,
2388
+ out: torch.Tensor,
2389
+ max_seq_len: int,
2390
+ ):
2391
+ b, (n_heads, d_head) = q.shape[0], q.shape[-2:]
2392
+ PAGE_SIZE, n_kv_heads = k_cache.shape[1:3]
2393
+ device = q.device
2394
+
2395
+ SEQ_BLOCK_SIZE = PAGE_SIZE # 256
2396
+ num_blocks = (max_seq_len + SEQ_BLOCK_SIZE - 1) // SEQ_BLOCK_SIZE
2397
+ stage1_output_values = torch.empty(
2398
+ b, n_heads, num_blocks, d_head, device=device, dtype=torch.float32
2399
+ )
2400
+ stage1_output_logsumexp = torch.empty(
2401
+ b, n_heads, num_blocks, device=device, dtype=torch.float32
2402
+ ) - float("inf")
2403
+
2404
+ (
2405
+ update_paged_kv_cache[(b, n_kv_heads, 1)](
2406
+ k,
2407
+ v,
2408
+ None,
2409
+ None,
2410
+ k_cache,
2411
+ v_cache,
2412
+ cache_loc,
2413
+ input_pos,
2414
+ page_table,
2415
+ n_kv_heads,
2416
+ d_head,
2417
+ SEQ_BLOCK_SIZE,
2418
+ max_seq_len,
2419
+ PAGE_SIZE,
2420
+ page_table.stride(0),
2421
+ GENERATE_ONLY=True,
2422
+ ),
2423
+ )
2424
+
2425
+ attention_kv_paged_stage1[
2426
+ (
2427
+ b,
2428
+ n_heads,
2429
+ num_blocks,
2430
+ )
2431
+ ](
2432
+ q,
2433
+ k_cache,
2434
+ v_cache,
2435
+ cache_loc,
2436
+ page_table,
2437
+ input_pos,
2438
+ stage1_output_values,
2439
+ stage1_output_logsumexp,
2440
+ num_blocks,
2441
+ max_seq_len,
2442
+ n_heads,
2443
+ n_kv_heads,
2444
+ d_head,
2445
+ SEQ_BLOCK_SIZE,
2446
+ PAGE_SIZE,
2447
+ page_table.stride(0),
2448
+ )
2449
+ attention_kv_stage2[(b, n_heads, 1)](
2450
+ stage1_output_values,
2451
+ stage1_output_logsumexp,
2452
+ out,
2453
+ input_pos,
2454
+ num_blocks,
2455
+ n_heads,
2456
+ d_head,
2457
+ SEQ_BLOCK_SIZE,
2458
+ )
2459
+
2460
+
2461
+ def _paged_context_mha(
2462
+ q: torch.Tensor,
2463
+ k: torch.Tensor,
2464
+ v: torch.Tensor,
2465
+ input_pos: torch.Tensor,
2466
+ cache_loc: torch.Tensor,
2467
+ page_table: torch.Tensor,
2468
+ k_cache: torch.Tensor,
2469
+ v_cache: torch.Tensor,
2470
+ seq_len: torch.Tensor,
2471
+ seq_start: torch.Tensor,
2472
+ out: torch.Tensor,
2473
+ max_seq_len: int, # max cache length of a sequence; the kv_cache shape doesn't provide this info.
2474
+ ) -> None:
2475
+ # NOTE: s_total == sum(seq_len)
2476
+ s_total, n_heads, d_head = q.shape
2477
+ PAGE_SIZE, n_kv_heads = k_cache.shape[1:3]
2478
+ BATCH_SIZE = len(input_pos)
2479
+ SEQ_BLOCK = PAGE_SIZE # 32
2480
+ (
2481
+ update_paged_kv_cache[
2482
+ (BATCH_SIZE, n_kv_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
2483
+ ](
2484
+ k,
2485
+ v,
2486
+ seq_len,
2487
+ seq_start,
2488
+ k_cache,
2489
+ v_cache,
2490
+ cache_loc,
2491
+ input_pos,
2492
+ page_table,
2493
+ n_kv_heads,
2494
+ d_head,
2495
+ SEQ_BLOCK,
2496
+ max_seq_len,
2497
+ PAGE_SIZE,
2498
+ page_table.stride(0),
2499
+ GENERATE_ONLY=False,
2500
+ ),
2501
+ )
2502
+ softmax_scale = 1.0 / math.sqrt(d_head)
2503
+ grid = (BATCH_SIZE, n_heads, (max(seq_len) + SEQ_BLOCK - 1) // SEQ_BLOCK)
2504
+ context_attention_kv_paged[grid](
2505
+ q,
2506
+ seq_len,
2507
+ seq_start,
2508
+ k_cache,
2509
+ v_cache,
2510
+ cache_loc,
2511
+ input_pos,
2512
+ page_table,
2513
+ softmax_scale,
2514
+ out,
2515
+ n_heads,
2516
+ n_kv_heads,
2517
+ d_head,
2518
+ SEQ_BLOCK,
2519
+ max_seq_len,
2520
+ PAGE_SIZE,
2521
+ page_table.stride(0),
2522
+ num_stages=2,
2523
+ )
2524
+
2525
+
2526
+ @torch.library.custom_op("attention::fused_mha_with_paged_cache", mutates_args=())
2527
+ def fused_mha_with_paged_cache(
2528
+ q: torch.Tensor,
2529
+ k: torch.Tensor,
2530
+ v: torch.Tensor,
2531
+ input_pos: torch.Tensor,
2532
+ cache_loc: torch.Tensor,
2533
+ seq_len: torch.Tensor,
2534
+ seq_start: torch.Tensor,
2535
+ page_table: torch.Tensor,
2536
+ max_seq_len: int,
2537
+ k_cache: torch.Tensor,
2538
+ v_cache: torch.Tensor,
2539
+ freqs_cis: Optional[torch.Tensor],
2540
+ ) -> torch.Tensor:
2541
+ """Fused MHA with paged cache that takes raw input from q, k, v GEMMs.
2542
+
2543
+ NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH.
2544
+ """
2545
+ # b, s info
2546
+ # NOTE: b, s are just the shapes of the input tensor q; not necessarily the number of sequences.
2547
+ # Generally speaking, we expect one of two cases here:
2548
+ # 1. b > 0, s==1: this indicates a generate-only batch of tokens.
2549
+ # 2. b==1, s > 0: this indicates a mixed context+generate phase. The actual number of sequences
2550
+ # and number of tokens per sequence are encoded in seq_len and seq_start.
2551
+ # Assuming that context seq_len is always > 0.
2552
+ b, s, d = q.shape
2553
+ head_dim = k_cache.shape[-1]
2554
+
2555
+ # reshapes with num_heads and head_dim
2556
+ if s == 1:
2557
+ bs_view = (b, s)
2558
+ else:
2559
+ bs_view = (b * s,)
2560
+ q = q.view(*bs_view, q.shape[2] // head_dim, head_dim)
2561
+ k = k.view(*bs_view, k.shape[2] // head_dim, head_dim)
2562
+ v = v.view(*bs_view, v.shape[2] // head_dim, head_dim)
2563
+
2564
+ # rope embedding for generate-only or mixed
2565
+ if freqs_cis is not None:
2566
+ if s == 1:
2567
+ rope_args = (freqs_cis, input_pos, "bsnd")
2568
+ fn_rope = torch.ops.rope.apply_rope_with_input_pos
2569
+ else:
2570
+ rope_args = (freqs_cis, input_pos, seq_len, seq_start)
2571
+ fn_rope = torch.ops.rope.apply_rope_on_flattened_inputs
2572
+ q = fn_rope(q, *rope_args)
2573
+ k = fn_rope(k, *rope_args)
2574
+
2575
+ # run attention
2576
+ y = torch.empty_like(q)
2577
+ if s == 1:
2578
+ # generate-only phase
2579
+ _paged_generate_mha(
2580
+ q, k, v, page_table, k_cache, v_cache, cache_loc, input_pos, y, max_seq_len
2581
+ )
2582
+ else:
2583
+ # mixed context + generate phase
2584
+ _paged_context_mha(
2585
+ q,
2586
+ k,
2587
+ v,
2588
+ input_pos,
2589
+ cache_loc,
2590
+ page_table,
2591
+ k_cache,
2592
+ v_cache,
2593
+ seq_len,
2594
+ seq_start,
2595
+ y,
2596
+ max_seq_len,
2597
+ )
2598
+
2599
+ return y.view(b, s, d) # [b,s,n*h_d]
2600
+
2601
+
2602
+ @fused_mha_with_paged_cache.register_fake
2603
+ def fused_mha_with_paged_cache_fake(
2604
+ q: torch.Tensor,
2605
+ k: torch.Tensor,
2606
+ v: torch.Tensor,
2607
+ input_pos: torch.Tensor,
2608
+ cache_loc: torch.Tensor,
2609
+ seq_len: torch.Tensor,
2610
+ seq_start: torch.Tensor,
2611
+ page_table: torch.Tensor,
2612
+ max_seq_len: int,
2613
+ k_cache: torch.Tensor,
2614
+ v_cache: torch.Tensor,
2615
+ freqs_cis: Optional[torch.Tensor],
2616
+ ) -> torch.Tensor:
2617
+ return torch.empty_like(q.contiguous())
2618
+
2619
+
2620
+ @torch.library.custom_op("attention::prepare_fused_mha_metadata", mutates_args=())
2621
+ def prepare_fused_mha_metadata(
2622
+ input_ids: torch.Tensor,
2623
+ seq_len: torch.Tensor,
2624
+ input_pos: torch.Tensor,
2625
+ cache_loc: torch.Tensor,
2626
+ pages_per_seq: torch.Tensor,
2627
+ page_size: int,
2628
+ ) -> List[torch.Tensor]:
2629
+ num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
2630
+ seq_start = torch.zeros_like(seq_len[:num_seq])
2631
+ seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
2632
+ return (
2633
+ seq_len[:num_seq].clone(),
2634
+ input_pos[:num_seq].clone(),
2635
+ cache_loc[:num_seq].clone(),
2636
+ seq_start,
2637
+ )
2638
+
2639
+
2640
+ @prepare_fused_mha_metadata.register_fake
2641
+ def prepare_fused_mha_metadata_fake(
2642
+ input_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
2643
+ ):
2644
+ return (
2645
+ torch.empty_like(seq_len),
2646
+ torch.empty_like(input_pos),
2647
+ torch.empty_like(cache_loc),
2648
+ torch.empty_like(seq_len),
2649
+ )
2650
+
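+ # Hedged example of the metadata computed above (illustrative values): for a context batch
+ # with seq_len[:num_seq] == [3, 2], the op returns seq_start == [0, 3] alongside the
+ # truncated seq_len, input_pos, and cache_loc tensors, i.e.
+ #
+ #   seq_start[0] = 0
+ #   seq_start[i] = seq_len[0] + ... + seq_len[i - 1]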
2651
+
2652
+ @AttentionRegistry.register("TritonWithFlattenedInputs")
2653
+ class TritonWithFlattenedInputs(AttentionDescriptor):
2654
+ @classmethod
2655
+ def is_paged(cls):
2656
+ """Return if the attention op is paged or not."""
2657
+ return False
2658
+
2659
+ @classmethod
2660
+ def get_attention_op(cls):
2661
+ return torch.ops.attention.fused_flattened_mha_with_cache, 3
2662
+
2663
+ @classmethod
2664
+ def get_prepare_metadata_op(cls):
2665
+ return torch.ops.attention.prepare_fused_mha_metadata, 4
2666
+
2667
+ @classmethod
2668
+ def get_cache_initializers(cls, get_info):
2669
+ def _get_cache(si: SequenceInfo):
2670
+ assert not si.is_paged, "Paged cache not supported for TritonWithFlattenedInputs"
2671
+ attention_info = get_info()
2672
+ return torch.empty(
2673
+ si.num_pages,
2674
+ si.page_size,
2675
+ attention_info.num_kv_heads,
2676
+ attention_info.head_dim,
2677
+ device=si.device,
2678
+ dtype=attention_info.cache_config.dtype or attention_info.dtype,
2679
+ )
2680
+
2681
+ return {"k_cache": _get_cache, "v_cache": _get_cache}
2682
+
2683
+ @classmethod
2684
+ def get_global_buffer_initializers(cls, get_info):
2685
+ attention_info = get_info()
2686
+ head_dim = attention_info.head_dim
2687
+ pos_embd_config = attention_info.pos_embd_config
2688
+
2689
+ def _get_freqs_cis(si: SequenceInfo):
2690
+ if pos_embd_config.mode is None:
2691
+ return torch.empty(0, device=si.device)
2692
+ assert pos_embd_config.mode == "rope", f"Mode {pos_embd_config.mode=} not supported"
2693
+ assert pos_embd_config.rope_scale == 1.0, f"{pos_embd_config.rope_scale=} not supported"
2694
+ rope_theta = pos_embd_config.rope_theta
2695
+ return cls._precompute_freqs_cis(2 * si.max_seq_len, head_dim, rope_theta).to(si.device)
2696
+
2697
+ k_full = "_".join(map(str, ["freqs_cis", *astuple(pos_embd_config)])).replace(".", "_")
2698
+ return {k_full: _get_freqs_cis}
2699
+
2700
+ @staticmethod
2701
+ def _precompute_freqs_cis(
2702
+ seq_len: int, head_dim: int, rope_theta: Optional[float] = None
2703
+ ) -> torch.Tensor:
2704
+ if rope_theta is None:
2705
+ rope_theta = 1e4
2706
+ freqs = 1.0 / (
2707
+ rope_theta ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim)
2708
+ )
2709
+ t = torch.arange(seq_len)
2710
+ freqs = torch.outer(t, freqs)
2711
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
2712
+ # cos and sin (real and img) are packed
2713
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
2714
+ return cache.to(dtype=torch.float16)
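+ # Hedged shape note for the cache computed above: for `seq_len` positions and `head_dim`
+ # channels the returned tensor is [seq_len, head_dim // 2, 2] in float16, with cos in
+ # [..., 0] and sin in [..., 1]; _get_freqs_cis requests 2 * si.max_seq_len positions, e.g.
+ #
+ #   cache = TritonWithFlattenedInputs._precompute_freqs_cis(8, 128)
+ #   assert cache.shape == (8, 64, 2) and cache.dtype == torch.float16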