zaydzuhri commited on
Commit
ad3ceb6
·
verified ·
1 Parent(s): b4ec538

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/models/delta_net/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc +0 -0
  3. fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc +0 -0
  4. fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc +0 -0
  5. fla/models/gated_deltanet/modeling_gated_deltanet.py +412 -0
  6. fla/models/gated_deltaproduct/__pycache__/__init__.cpython-312.pyc +0 -0
  7. fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc +0 -0
  8. fla/models/gla/__pycache__/configuration_gla.cpython-312.pyc +0 -0
  9. fla/models/gsa/__pycache__/__init__.cpython-312.pyc +0 -0
  10. fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc +0 -0
  11. fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc +0 -0
  12. fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc +0 -0
  13. fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
  14. fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc +0 -0
  15. fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc +0 -0
  16. fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc +0 -0
  17. fla/models/mamba2/configuration_mamba2.py +170 -0
  18. fla/models/retnet/__pycache__/__init__.cpython-312.pyc +0 -0
  19. fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc +0 -0
  20. fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc +0 -0
  21. fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc +0 -0
  22. fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  23. fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-312.pyc +0 -0
  24. fla/models/transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  25. fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  26. fla/models/transformer_dsmtp/__pycache__/__init__.cpython-312.pyc +0 -0
  27. fla/models/transformer_dsmtp/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  28. fla/models/transformer_dsmtp/__pycache__/modeling_transformer.cpython-312.pyc +0 -0
  29. fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc +0 -0
  30. fla/ops/__pycache__/__init__.cpython-312.pyc +0 -0
  31. fla/ops/abc/chunk.py +1116 -0
  32. fla/ops/based/__init__.py +9 -0
  33. fla/ops/common/__init__.py +1 -0
  34. fla/ops/common/chunk_delta_h.py +399 -0
  35. fla/ops/common/chunk_h.py +422 -0
  36. fla/ops/common/chunk_o.py +668 -0
  37. fla/ops/common/chunk_scaled_dot_kkt.py +126 -0
  38. fla/ops/common/fused_recurrent.py +575 -0
  39. fla/ops/common/utils.py +69 -0
  40. fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  41. fla/ops/delta_rule/fused_chunk.py +6 -0
  42. fla/ops/forgetting_attn/parallel.py +708 -0
  43. fla/ops/gated_delta_rule/__init__.py +7 -0
  44. fla/ops/gated_delta_rule/chunk.py +392 -0
  45. fla/ops/generalized_delta_rule/__init__.py +9 -0
  46. fla/ops/gla/fused_recurrent.py +113 -0
  47. fla/ops/gla/naive.py +41 -0
  48. fla/ops/gsa/__init__.py +9 -0
  49. fla/ops/hgrn/__init__.py +9 -0
  50. fla/ops/hgrn/fused_recurrent.py +308 -0
fla/models/delta_net/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (692 Bytes). View file
 
fla/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc ADDED
Binary file (3.58 kB). View file
 
fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc ADDED
Binary file (3.33 kB). View file
 
fla/models/gated_deltanet/modeling_gated_deltanet.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gated_deltanet import GatedDeltaNet
20
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GatedDeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class GatedDeltaNetBlock(nn.Module):
34
+ def __init__(self, config: GatedDeltaNetConfig, layer_idx: int):
35
+ super().__init__()
36
+
37
+ self.config = config
38
+ self.layer_idx = layer_idx
39
+
40
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
41
+ if config.attn is not None and layer_idx in config.attn['layers']:
42
+ self.attn = Attention(
43
+ hidden_size=config.hidden_size,
44
+ num_heads=config.attn['num_heads'],
45
+ num_kv_heads=config.attn['num_kv_heads'],
46
+ qkv_bias=config.attn['qkv_bias'],
47
+ window_size=config.attn['window_size'],
48
+ rope_theta=config.attn['rope_theta'],
49
+ max_position_embeddings=config.max_position_embeddings,
50
+ layer_idx=layer_idx
51
+ )
52
+ else:
53
+ self.attn = GatedDeltaNet(
54
+ mode=config.attn_mode,
55
+ hidden_size=config.hidden_size,
56
+ expand_v=config.expand_v,
57
+ head_dim=config.head_dim,
58
+ num_heads=config.num_heads,
59
+ use_gate=config.use_gate,
60
+ use_short_conv=config.use_short_conv,
61
+ conv_size=config.conv_size,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = GatedDeltaNetMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs: Unpack[Dict]
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ hidden_states = self.attn_norm(hidden_states)
85
+ hidden_states, attentions, past_key_values = self.attn(
86
+ hidden_states=hidden_states,
87
+ attention_mask=attention_mask,
88
+ past_key_values=past_key_values,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ **kwargs
92
+ )
93
+ if self.config.fuse_norm:
94
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
95
+ else:
96
+ hidden_states = residual + hidden_states
97
+ residual = hidden_states
98
+ hidden_states = self.mlp_norm(hidden_states)
99
+ hidden_states = self.mlp(hidden_states, **kwargs)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states, attentions, past_key_values)
103
+
104
+ return outputs
105
+
106
+
107
+ class GatedDeltaNetPreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = GatedDeltaNetConfig
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['GatedDeltaNetBlock']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel):
163
+
164
+ def __init__(self, config: GatedDeltaNetConfig):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
171
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
172
+
173
+ self.gradient_checkpointing = False
174
+
175
+ self.post_init()
176
+
177
+ def get_input_embeddings(self):
178
+ return self.embeddings
179
+
180
+ def set_input_embeddings(self, value):
181
+ self.embeddings = value
182
+
183
+ def forward(
184
+ self,
185
+ input_ids: Optional[torch.LongTensor] = None,
186
+ attention_mask: Optional[torch.Tensor] = None, # noqa
187
+ inputs_embeds: Optional[torch.FloatTensor] = None,
188
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
189
+ use_cache: Optional[bool] = None,
190
+ output_attentions: Optional[bool] = None,
191
+ output_hidden_states: Optional[bool] = None,
192
+ return_dict: Optional[bool] = None,
193
+ **kwargs: Unpack[Dict]
194
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
195
+ if output_attentions:
196
+ warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
197
+ output_attentions = False
198
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
199
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
200
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
201
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
202
+
203
+ # retrieve input_ids and inputs_embeds
204
+ if input_ids is not None and inputs_embeds is not None:
205
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
206
+ if input_ids is None and inputs_embeds is None:
207
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
208
+
209
+ if inputs_embeds is None:
210
+ inputs_embeds = self.embeddings(input_ids)
211
+ hidden_states = inputs_embeds
212
+
213
+ if use_cache and not isinstance(past_key_values, Cache):
214
+ past_key_values = Cache.from_legacy_cache(past_key_values)
215
+
216
+ if self.gradient_checkpointing and self.training and use_cache:
217
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
218
+ use_cache = False
219
+
220
+ all_hidden_states = () if output_hidden_states else None
221
+ all_attns = () if output_attentions else None
222
+ for layer in self.layers:
223
+ if output_hidden_states:
224
+ all_hidden_states += (hidden_states,)
225
+
226
+ if self.gradient_checkpointing and self.training:
227
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
228
+ layer.__call__,
229
+ hidden_states,
230
+ attention_mask,
231
+ past_key_values,
232
+ use_cache,
233
+ output_attentions,
234
+ **kwargs
235
+ )
236
+ else:
237
+ hidden_states, attentions, past_key_values = layer(
238
+ hidden_states,
239
+ attention_mask=attention_mask,
240
+ past_key_values=past_key_values,
241
+ use_cache=use_cache,
242
+ output_attentions=output_attentions,
243
+ **kwargs
244
+ )
245
+
246
+ if output_attentions:
247
+ all_attns += (attentions,)
248
+
249
+ hidden_states = self.norm(hidden_states)
250
+
251
+ # add hidden states from the last decoder layer
252
+ if output_hidden_states:
253
+ all_hidden_states += (hidden_states,)
254
+
255
+ if not return_dict:
256
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
257
+ return BaseModelOutputWithPast(
258
+ last_hidden_state=hidden_states,
259
+ past_key_values=past_key_values,
260
+ hidden_states=all_hidden_states,
261
+ attentions=all_attns
262
+ )
263
+
264
+
265
+ class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
266
+
267
+ _tied_weights_keys = ["lm_head.weight"]
268
+
269
+ def __init__(self, config):
270
+ super().__init__(config)
271
+ self.model = GatedDeltaNetModel(config)
272
+ self.vocab_size = config.vocab_size
273
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
274
+ self.criterion = None
275
+
276
+ # Initialize weights and apply final processing
277
+ self.post_init()
278
+
279
+ def get_input_embeddings(self):
280
+ return self.model.embeddings
281
+
282
+ def set_input_embeddings(self, value):
283
+ self.model.embeddings = value
284
+
285
+ def get_output_embeddings(self):
286
+ return self.lm_head
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.lm_head = new_embeddings
290
+
291
+ def set_decoder(self, decoder):
292
+ self.model = decoder
293
+
294
+ def get_decoder(self):
295
+ return self.model
296
+
297
+ def generate(self, *args, **kwargs):
298
+ try:
299
+ return super().generate(*args, **kwargs)
300
+ except AttributeError as exception:
301
+ if 'past_key_values' in str(exception):
302
+ raise AttributeError(
303
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
304
+ f"which is not supported for {self.__class__.__name__}. "
305
+ f"Try another generation strategy instead. "
306
+ f"For the available generation strategies, check this doc: "
307
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
308
+ )
309
+ else:
310
+ raise exception
311
+
312
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
313
+ def prepare_inputs_for_generation(
314
+ self,
315
+ input_ids: torch.LongTensor = None,
316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ inputs_embeds: Optional[torch.Tensor] = None,
319
+ use_cache: bool = True,
320
+ logits_to_keep: Optional[int] = None,
321
+ **kwargs
322
+ ):
323
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
324
+ if past_key_values is not None and len(past_key_values) > 0:
325
+ input_ids = input_ids[:, -1:]
326
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
327
+ if inputs_embeds is not None and len(past_key_values) == 0:
328
+ model_inputs = {'inputs_embeds': inputs_embeds}
329
+ else:
330
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
331
+ # recompiles graphs as the stride of the inputs is a guard.
332
+ # Ref: https://github.com/huggingface/transformers/pull/29114
333
+ # TODO: use `next_tokens` directly instead.
334
+ model_inputs = {'input_ids': input_ids.contiguous()}
335
+
336
+ if logits_to_keep is not None:
337
+ model_inputs['logits_to_keep'] = logits_to_keep
338
+
339
+ model_inputs.update({
340
+ 'past_key_values': past_key_values,
341
+ 'use_cache': use_cache,
342
+ 'attention_mask': attention_mask,
343
+ })
344
+ return model_inputs
345
+
346
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
347
+ def forward(
348
+ self,
349
+ input_ids: torch.LongTensor = None,
350
+ attention_mask: Optional[torch.Tensor] = None,
351
+ inputs_embeds: Optional[torch.Tensor] = None,
352
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
353
+ labels: Optional[torch.LongTensor] = None,
354
+ use_cache: Optional[bool] = None,
355
+ output_attentions: Optional[bool] = None,
356
+ output_hidden_states: Optional[bool] = None,
357
+ return_dict: Optional[bool] = None,
358
+ logits_to_keep: Optional[int] = 0,
359
+ **kwargs: Unpack[Dict]
360
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
361
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
362
+ output_hidden_states = (
363
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
364
+ )
365
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
+
367
+ outputs = self.model(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ inputs_embeds=inputs_embeds,
371
+ past_key_values=past_key_values,
372
+ use_cache=use_cache,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=output_hidden_states,
375
+ return_dict=return_dict,
376
+ **kwargs
377
+ )
378
+
379
+ hidden_states = outputs[0]
380
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
381
+
382
+ loss, logits = None, None
383
+ if not fuse_linear_and_cross_entropy or labels is None:
384
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
385
+ if labels is not None:
386
+ if getattr(self, 'criterion', None) is None:
387
+ if fuse_linear_and_cross_entropy:
388
+ criterion = FusedLinearCrossEntropyLoss()
389
+ elif self.config.fuse_cross_entropy:
390
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
391
+ else:
392
+ criterion = nn.CrossEntropyLoss()
393
+ else:
394
+ criterion = self.criterion
395
+ labels = labels.to(hidden_states.device)
396
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
397
+ if fuse_linear_and_cross_entropy:
398
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
399
+ else:
400
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
401
+
402
+ if not return_dict:
403
+ output = (logits,) + outputs[1:]
404
+ return (loss,) + output if loss is not None else output
405
+
406
+ return CausalLMOutputWithPast(
407
+ loss=loss,
408
+ logits=logits,
409
+ past_key_values=outputs.past_key_values,
410
+ hidden_states=outputs.hidden_states,
411
+ attentions=outputs.attentions,
412
+ )
fla/models/gated_deltaproduct/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (768 Bytes). View file
 
fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-312.pyc ADDED
Binary file (3.37 kB). View file
 
fla/models/gla/__pycache__/configuration_gla.cpython-312.pyc ADDED
Binary file (3.72 kB). View file
 
fla/models/gsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (648 Bytes). View file
 
fla/models/gsa/__pycache__/configuration_gsa.cpython-312.pyc ADDED
Binary file (3.83 kB). View file
 
fla/models/hgrn2/__pycache__/modeling_hgrn2.cpython-312.pyc ADDED
Binary file (18.9 kB). View file
 
fla/models/lightnet/__pycache__/configuration_lightnet.cpython-312.pyc ADDED
Binary file (3.35 kB). View file
 
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc ADDED
Binary file (18.3 kB). View file
 
fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc ADDED
Binary file (3.64 kB). View file
 
fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc ADDED
Binary file (7.05 kB). View file
 
fla/models/mamba2/configuration_mamba2.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MAMBA2 configuration"""
15
+
16
+ import math
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+
20
+
21
+ class Mamba2Config(PretrainedConfig):
22
+ """
23
+ This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
24
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
25
+ defaults will yield a similar configuration to that of the MAMBA2
26
+ [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.
27
+
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+
32
+ Args:
33
+ num_heads (`int`, *optional*, defaults to 64):
34
+ Number of heads for the evolution matrices of mamba 2.
35
+ head_dim (`int`, *optional*, defaults to 64):
36
+ Dimension of each head.
37
+ vocab_size (`int`, *optional*, defaults to 32768):
38
+ Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Mamba2Model`].
40
+ hidden_size (`int`, *optional*, defaults to 2048):
41
+ Dimensionality of the embeddings and hidden states.
42
+ state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
43
+ num_hidden_layers (`int`, *optional*, defaults to 48):
44
+ Number of hidden layers in the model.
45
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
46
+ The epsilon to use in the layer normalization layers.
47
+ pad_token_id (`int`, *optional*, defaults to 0):
48
+ Padding token id.
49
+ bos_token_id (`int`, *optional*, defaults to 1):
50
+ The id of the beginning of sentence token in the vocabulary.
51
+ eos_token_id (`int`, *optional*, defaults to 2):
52
+ The id of the end of sentence token in the vocabulary.
53
+ expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
54
+ conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
55
+ n_groups (`int`, *optional*, defaults to 1):
56
+ Number of groups for the evolution matrices of mamba 2.
57
+ use_bias (`bool`, *optional*, defaults to `False`):
58
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
59
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
60
+ Whether or not to use bias in the convolution layer of the mixer block.
61
+ hidden_act (`str`, *optional*, defaults to `"silu"`):
62
+ The non-linear activation function (function or string) in the decoder.
63
+ initializer_range (`float`, *optional*, defaults to 0.1):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ residual_in_fp32 (`bool`, *optional*, defaults to `True`):
66
+ Whether or not residuals should be in `float32`.
67
+ If set to `False` residuals will keep the same `dtype` as the rest of the model
68
+ time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
69
+ Rank of the discretization projection matrix.
70
+ `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
71
+ time_step_min (`float`, *optional*, defaults to 0.001):
72
+ Minimum `time_step` used to bound `dt_proj.bias`.
73
+ time_step_max (`float`, *optional*, defaults to 0.1):
74
+ Maximum `time_step` used to bound `dt_proj.bias`.
75
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
76
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
77
+ time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
78
+ Accepted range of time step values.
79
+ rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
80
+ Whether or not to rescale `out_proj` weights when initializing.
81
+ use_cache (`bool`, *optional*, defaults to `True`):
82
+ Whether or not the cache should be used.
83
+ rms_norm (`bool`, *optional*, defaults to `True`):
84
+ Whether to use RMS norm or not.
85
+ chunk_size (`int`, *optional*, defaults to 256):
86
+ Size of the chunks that will comprise the sequence.
87
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
88
+ Whether to tie word embeddings or not.
89
+ """
90
+
91
+ model_type = "mamba2"
92
+
93
+ def __init__(
94
+ self,
95
+ num_heads: int = 64,
96
+ head_dim: int = 64,
97
+ vocab_size: int = 32000,
98
+ hidden_size: int = 2048,
99
+ state_size: int = 128,
100
+ num_hidden_layers: int = 48,
101
+ layer_norm_epsilon: float = 1e-5,
102
+ pad_token_id: int = 0,
103
+ bos_token_id: int = 1,
104
+ eos_token_id: int = 2,
105
+ expand: int = 2,
106
+ conv_kernel: int = 4,
107
+ n_groups: int = 1,
108
+ use_bias: bool = False,
109
+ use_conv_bias: bool = True,
110
+ hidden_act: str = "silu",
111
+ initializer_range: float = 0.1,
112
+ residual_in_fp32: bool = True,
113
+ time_step_rank: str = "auto",
114
+ time_step_min: float = 0.001,
115
+ time_step_max: float = 0.1,
116
+ time_step_floor: float = 1e-4,
117
+ time_step_limit=(0.0, float("inf")),
118
+ rescale_prenorm_residual: bool = True,
119
+ use_cache: bool = True,
120
+ rms_norm: bool = True,
121
+ chunk_size: int = 256,
122
+ fuse_norm: bool = True,
123
+ fuse_cross_entropy: bool = True,
124
+ tie_word_embeddings: bool = False,
125
+ **kwargs,
126
+ ):
127
+ self.vocab_size = vocab_size
128
+ self.hidden_size = hidden_size
129
+ self.state_size = state_size
130
+ self.num_hidden_layers = num_hidden_layers
131
+ self.layer_norm_epsilon = layer_norm_epsilon
132
+ self.conv_kernel = conv_kernel
133
+ self.expand = expand
134
+
135
+ self.bos_token_id = bos_token_id
136
+ self.eos_token_id = eos_token_id
137
+ self.pad_token_id = pad_token_id
138
+ self.use_bias = use_bias
139
+ self.use_conv_bias = use_conv_bias
140
+ self.hidden_act = hidden_act
141
+ self.initializer_range = initializer_range
142
+ self.time_step_rank = (
143
+ math.ceil(self.hidden_size / 16)
144
+ if time_step_rank == "auto"
145
+ else time_step_rank
146
+ )
147
+ self.time_step_min = time_step_min
148
+ self.time_step_max = time_step_max
149
+ self.time_step_floor = time_step_floor
150
+ self.rescale_prenorm_residual = rescale_prenorm_residual
151
+ self.residual_in_fp32 = residual_in_fp32
152
+ self.use_cache = use_cache
153
+ self.n_groups = n_groups
154
+ self.num_heads = num_heads
155
+ self.head_dim = head_dim
156
+ self.rms_norm = rms_norm
157
+ self.state_size = state_size
158
+ self.chunk_size = chunk_size
159
+ self.time_step_limit = time_step_limit
160
+ self.fuse_norm = fuse_norm
161
+ self.fuse_cross_entropy = fuse_cross_entropy
162
+ self.tie_word_embeddings = tie_word_embeddings
163
+
164
+ super().__init__(
165
+ bos_token_id=bos_token_id,
166
+ eos_token_id=eos_token_id,
167
+ pad_token_id=pad_token_id,
168
+ tie_word_embeddings=tie_word_embeddings,
169
+ **kwargs,
170
+ )
fla/models/retnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (673 Bytes). View file
 
fla/models/retnet/__pycache__/modeling_retnet.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
fla/models/rwkv6/__pycache__/configuration_rwkv6.cpython-312.pyc ADDED
Binary file (3.31 kB). View file
 
fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc ADDED
Binary file (21.1 kB). View file
 
fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (678 Bytes). View file
 
fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-312.pyc ADDED
Binary file (22.3 kB). View file
 
fla/models/transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (719 Bytes). View file
 
fla/models/transformer/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.51 kB). View file
 
fla/models/transformer_dsmtp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (752 Bytes). View file
 
fla/models/transformer_dsmtp/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.61 kB). View file
 
fla/models/transformer_dsmtp/__pycache__/modeling_transformer.cpython-312.pyc ADDED
Binary file (20.6 kB). View file
 
fla/models/transformer_top/__pycache__/configuration_transformer.cpython-312.pyc ADDED
Binary file (2.69 kB). View file
 
fla/ops/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.89 kB). View file
 
fla/ops/abc/chunk.py ADDED
@@ -0,0 +1,1116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import logcumsumexp_fwd_kernel, softmax_bwd, softmax_fwd
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.jit(do_not_specialize=['T'])
16
+ def chunk_abc_fwd_kernel_h(
17
+ k,
18
+ v,
19
+ z,
20
+ h,
21
+ h0,
22
+ ht,
23
+ T,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ NT: tl.constexpr,
30
+ NORMK: tl.constexpr,
31
+ USE_INITIAL_STATE: tl.constexpr,
32
+ STORE_FINAL_STATE: tl.constexpr
33
+ ):
34
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
35
+
36
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
37
+ if USE_INITIAL_STATE:
38
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
39
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
40
+ if NORMK:
41
+ p_z0 = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_k * BK,), (BK,), (0,))
42
+ else:
43
+ p_z0 = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_v * BV,), (BV,), (0,))
44
+ b_zp = tl.load(p_z0).to(tl.float32)
45
+ for i_t in range(NT):
46
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
47
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
48
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
49
+
50
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
51
+ # [BK, BT]
52
+ b_k = tl.load(p_k, boundary_check=(0, 1))
53
+ # [BT, BV]
54
+ b_v = tl.load(p_v, boundary_check=(0, 1))
55
+ if NORMK:
56
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
57
+ # [BK,]
58
+ b_zc = tl.load(p_zc, boundary_check=(0,))
59
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
60
+ # [BK, BV]
61
+ b_h = b_h * b_r[:, None]
62
+ b_k = exp(b_k - b_zc[:, None]).to(b_k.dtype)
63
+ else:
64
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
65
+ # [BV,]
66
+ b_zc = tl.load(p_zc, boundary_check=(0,))
67
+ b_r, b_zp = exp(b_zp - b_zc), b_zc
68
+ # [BK, BV]
69
+ b_h = b_h * b_r[None, :]
70
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
71
+ # [BK, BV]
72
+ b_h += tl.dot(b_k, b_v, allow_tf32=False)
73
+
74
+ if STORE_FINAL_STATE:
75
+ p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
76
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+
79
+ @triton.jit(do_not_specialize=['T'])
80
+ def chunk_abc_fwd_kernel_intra_K(
81
+ v,
82
+ z,
83
+ o,
84
+ A,
85
+ T,
86
+ V: tl.constexpr,
87
+ BT: tl.constexpr,
88
+ BC: tl.constexpr,
89
+ BV: tl.constexpr,
90
+ NC: tl.constexpr
91
+ ):
92
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ i_t, i_i = i_c // NC, i_c % NC
94
+
95
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
97
+ # [BV,]
98
+ b_zn = tl.load(p_zn, boundary_check=(0,))
99
+ # [BC, BV]
100
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
101
+ for i_j in range(0, i_i):
102
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
103
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
104
+ # [BC, BV]
105
+ b_v = tl.load(p_v, boundary_check=(0, 1))
106
+ # [BC, BC]
107
+ b_A = tl.load(p_A, boundary_check=(0, 1))
108
+ b_o += tl.dot(b_A, exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False)
109
+ b_z = tl.load(p_z, boundary_check=(0, 1))
110
+ b_o *= exp(b_zn[None, :] - b_z)
111
+
112
+ o_i = tl.arange(0, BC)
113
+ o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
114
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
115
+ for j in range(0, BC):
116
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
117
+ # [BC,]
118
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
119
+ # [BV,]
120
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
121
+ # [BC, BV]
122
+ # avoid 0 * inf = inf
123
+ m_i = o_i[:, None] >= j
124
+ b_o += tl.where(m_i, b_A[:, None] * exp(b_v[None, :] - b_z), 0)
125
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+
128
+
129
+ @triton.jit(do_not_specialize=['T'])
130
+ def chunk_abc_fwd_kernel_K(
131
+ q,
132
+ k,
133
+ z,
134
+ h,
135
+ o,
136
+ A,
137
+ scale,
138
+ T,
139
+ K: tl.constexpr,
140
+ V: tl.constexpr,
141
+ BT: tl.constexpr,
142
+ BK: tl.constexpr,
143
+ BV: tl.constexpr,
144
+ NT: tl.constexpr
145
+ ):
146
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
147
+ i_p = tl.maximum(i_t * BT - 1, 0)
148
+
149
+ o_i = tl.arange(0, BT)
150
+ m_s = o_i[:, None] >= o_i[None, :]
151
+
152
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
153
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
154
+ for i_k in range(tl.cdiv(K, BK)):
155
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
156
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
157
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
158
+
159
+ # [BT, BK]
160
+ b_q = tl.load(p_q, boundary_check=(0, 1))
161
+ b_q = (b_q * scale).to(b_q.dtype)
162
+ # [BK, BT]
163
+ b_k = tl.load(p_k, boundary_check=(0, 1))
164
+ # [BK, BV]
165
+ b_h = tl.load(p_h, boundary_check=(0, 1))
166
+ # [BT, BV]
167
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
168
+ # [BT, BT]
169
+ b_A += tl.dot(b_q, b_k, allow_tf32=False)
170
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
171
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
172
+ # [BT, BV]
173
+ b_z = tl.load(p_z, boundary_check=(0, 1))
174
+ # [BT, BV]
175
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
176
+ b_zp = tl.load(p_zp, boundary_check=(0,))
177
+ b_o = b_o * exp(b_zp[None, :] - b_z)
178
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
179
+
180
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
181
+ # [BT, BT]
182
+ b_A = tl.where(m_s, b_A, 0.)
183
+ if i_v == 0:
184
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
185
+
186
+
187
+ @triton.jit(do_not_specialize=['T'])
188
+ def chunk_abc_fwd_kernel_intra_V(
189
+ q,
190
+ k,
191
+ z,
192
+ A,
193
+ scale,
194
+ T,
195
+ K: tl.constexpr,
196
+ BT: tl.constexpr,
197
+ BC: tl.constexpr,
198
+ BK: tl.constexpr,
199
+ NC: tl.constexpr
200
+ ):
201
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
203
+ n_bh = tl.num_programs(2)
204
+
205
+ if i_i > i_j:
206
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
207
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
208
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
209
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
210
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
211
+ # [BK,]
212
+ b_zn = tl.load(p_zn, boundary_check=(0,))
213
+ # [BC, BK]
214
+ b_q = tl.load(p_q, boundary_check=(0, 1))
215
+ b_z = tl.load(p_z, boundary_check=(0, 1))
216
+ b_q = (b_q * exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype)
217
+ # [BK, BC]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ b_k = exp(b_k - b_zn[:, None]).to(b_k.dtype)
220
+ # [BC, BC]
221
+ b_A = tl.dot(b_q, b_k, allow_tf32=False)
222
+ tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
223
+ elif i_i == i_j:
224
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
225
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,))
226
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
227
+ # [BC, BK]
228
+ b_q = tl.load(p_q, boundary_check=(0, 1))
229
+ b_z = tl.load(p_z, boundary_check=(0, 1))
230
+
231
+ o_i = tl.arange(0, BC)
232
+ o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
233
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
234
+ for j in range(0, BC):
235
+ # [BK,]
236
+ b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32)
237
+ # [BC,]
238
+ b_A = tl.sum(b_q * exp(b_k[None, :] - b_z) * scale, 1)
239
+ b_A = tl.where(o_i >= j, b_A, 0.)
240
+ tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
241
+
242
+ p_k = tl.advance(p_k, (K,))
243
+
244
+
245
+ @triton.jit(do_not_specialize=['T'])
246
+ def chunk_abc_fwd_kernel_V(
247
+ q,
248
+ v,
249
+ z,
250
+ h,
251
+ o,
252
+ A,
253
+ scale,
254
+ T,
255
+ K: tl.constexpr,
256
+ V: tl.constexpr,
257
+ BT: tl.constexpr,
258
+ BK: tl.constexpr,
259
+ BV: tl.constexpr,
260
+ NT: tl.constexpr
261
+ ):
262
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
263
+ i_p = tl.maximum(i_t * BT - 1, 0)
264
+
265
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
266
+ for i_k in range(tl.cdiv(K, BK)):
267
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
268
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
269
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
270
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
271
+
272
+ # [BT, BK]
273
+ b_q = tl.load(p_q, boundary_check=(0, 1))
274
+ b_q = (b_q * scale).to(b_q.dtype)
275
+ # [BT, BK]
276
+ b_z = tl.load(p_z, boundary_check=(0, 1))
277
+ # [BT, BK]
278
+ b_zp = tl.load(p_zp, boundary_check=(0,))
279
+ b_q = (b_q * exp(b_zp[None, :] - b_z)).to(b_q.dtype)
280
+ # [BK, BV]
281
+ b_h = tl.load(p_h, boundary_check=(0, 1))
282
+ # works but dkw, owing to divine benevolence
283
+ # [BT, BV]
284
+ if i_k >= 0:
285
+ b_o += tl.dot(b_q, b_h, allow_tf32=False)
286
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
287
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
288
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
289
+ # [BT, BV]
290
+ b_v = tl.load(p_v, boundary_check=(0, 1))
291
+ # [BT, BT]
292
+ b_A = tl.load(p_A, boundary_check=(0, 1))
293
+ b_o += tl.dot(b_A.to(b_v.dtype), b_v, allow_tf32=False)
294
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
295
+
296
+
297
+ @triton.jit(do_not_specialize=['T'])
298
+ def chunk_abc_bwd_kernel_dh(
299
+ q,
300
+ z,
301
+ do,
302
+ dh,
303
+ scale,
304
+ T,
305
+ K: tl.constexpr,
306
+ V: tl.constexpr,
307
+ BT: tl.constexpr,
308
+ BK: tl.constexpr,
309
+ BV: tl.constexpr,
310
+ NT: tl.constexpr,
311
+ NORMK: tl.constexpr
312
+ ):
313
+ i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
314
+
315
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
316
+ b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32)
317
+ for i_t in range(NT - 1, -1, -1):
318
+ i_p = tl.maximum(i_t * BT - 1, 0)
319
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
320
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
321
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+
323
+ # [BK, BT]
324
+ b_q = tl.load(p_q, boundary_check=(0, 1))
325
+ b_q = (b_q * scale).to(b_q.dtype)
326
+ # [BT, BV]
327
+ b_do = tl.load(p_do, boundary_check=(0, 1))
328
+
329
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
330
+ if NORMK:
331
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
333
+ # [BK,]
334
+ b_zc = tl.load(p_zc, boundary_check=(0,))
335
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
336
+ # [BK, BT]
337
+ b_z = tl.load(p_z, boundary_check=(0, 1))
338
+ b_q = (b_q * exp(b_zc[:, None] - b_z)).to(b_q.dtype)
339
+ # [BK, BV]
340
+ b_dh = b_dh * b_r[:, None]
341
+ else:
342
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
343
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
344
+ # [BV,]
345
+ b_zc = tl.load(p_zc, boundary_check=(0,))
346
+ b_r, b_zp = exp(b_zc - b_zp), b_zc
347
+ # [BT, BV]
348
+ b_z = tl.load(p_z, boundary_check=(0,))
349
+ b_do = (b_do * exp(b_zc[None, :] - b_z)).to(b_do.dtype)
350
+ # [BK, BV]
351
+ b_dh = b_dh * b_r[None, :]
352
+ # [BK, BV]
353
+ b_dh += tl.dot(b_q, b_do, allow_tf32=False)
354
+
355
+
356
+ @triton.jit(do_not_specialize=['T'])
357
+ def chunk_abc_bwd_kernel_V(
358
+ k,
359
+ v,
360
+ z,
361
+ h,
362
+ A,
363
+ do,
364
+ dh,
365
+ dq,
366
+ dk,
367
+ dv,
368
+ dA,
369
+ scale,
370
+ T,
371
+ K: tl.constexpr,
372
+ V: tl.constexpr,
373
+ BT: tl.constexpr,
374
+ BK: tl.constexpr,
375
+ BV: tl.constexpr,
376
+ NT: tl.constexpr
377
+ ):
378
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
379
+ i_p = tl.maximum(i_t * BT - 1, 0)
380
+ n_bh = tl.num_programs(2)
381
+
382
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
383
+ p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))
384
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
385
+
386
+ # [BK,]
387
+ b_zc = tl.load(p_zc, boundary_check=(0,))
388
+ # [BT, BK]
389
+ b_k = tl.load(p_k, boundary_check=(0, 1))
390
+ b_k = exp(b_k - b_zc[None, :]).to(b_k.dtype)
391
+ # [BT, BT]
392
+ b_A = tl.load(p_A, boundary_check=(0, 1))
393
+
394
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
395
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
396
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
397
+ for i_v in range(tl.cdiv(V, BV)):
398
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
399
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
400
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
401
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
402
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
403
+
404
+ # [BT, BV]
405
+ b_v = tl.load(p_v, boundary_check=(0, 1))
406
+ # [BV, BK]
407
+ b_h = tl.load(p_h, boundary_check=(0, 1))
408
+ # [BT, BV]
409
+ b_do = tl.load(p_do, boundary_check=(0, 1))
410
+ # [BK, BV]
411
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
412
+
413
+ # [BT, BV]
414
+ b_dv = tl.dot(b_k, b_dh, allow_tf32=False)
415
+ if i_k == 0:
416
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do, allow_tf32=False)
417
+ b_do = (b_do * scale).to(b_do.dtype)
418
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
419
+ # [BT, BT]
420
+ b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
421
+ # [BT, BK]
422
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
423
+ # [BT, BK]
424
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
425
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
426
+ p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,))
427
+ # [BK,]
428
+ b_zp = tl.load(p_zp, boundary_check=(0,))
429
+ # [BT, BK]
430
+ b_z = tl.load(p_z, boundary_check=(0, 1))
431
+ b_z = exp(b_zp[None, :] - b_z)
432
+ # [BT, BK]
433
+ b_dq = b_dq * b_z
434
+ b_dk = b_dk * b_k
435
+
436
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
437
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
438
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
439
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
440
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
441
+
442
+ o_i = tl.arange(0, BT)
443
+ m_s = o_i[:, None] >= o_i[None, :]
444
+ # [BT, BT]
445
+ b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype)
446
+ if i_k == 0:
447
+ tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
448
+
449
+
450
+ @triton.jit(do_not_specialize=['T'])
451
+ def chunk_abc_bwd_kernel_intra_V(
452
+ q,
453
+ k,
454
+ z,
455
+ dA,
456
+ dq,
457
+ dk,
458
+ T,
459
+ K: tl.constexpr,
460
+ BT: tl.constexpr,
461
+ BC: tl.constexpr,
462
+ BK: tl.constexpr,
463
+ NC: tl.constexpr
464
+ ):
465
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
466
+ i_t, i_i = i_c // NC, i_c % NC
467
+
468
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
469
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
470
+ # [BK,]
471
+ b_zn = tl.load(p_zn, boundary_check=(0,))
472
+ # [BC, BK]
473
+ b_z = tl.load(p_z, boundary_check=(0, 1))
474
+ b_zq = exp(b_zn[None, :] - b_z)
475
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
476
+ for i_j in range(0, i_i):
477
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
478
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
479
+ # [BC, BK]
480
+ b_k = tl.load(p_k, boundary_check=(0, 1))
481
+ b_kz = exp(b_k - b_zn[None, :]).to(b_k.dtype)
482
+ # [BC, BC]
483
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
484
+ # [BC, BK]
485
+ b_dq += tl.dot(b_dA, b_kz, allow_tf32=False)
486
+ b_dq *= b_zq
487
+
488
+ o_i = tl.arange(0, BC)
489
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
490
+ m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
491
+ for j in range(0, BC):
492
+ p_kj = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,))
493
+ # [BC,]
494
+ b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0)
495
+ # [BK,]
496
+ b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32)
497
+ # [BC, BK]
498
+ m_i = o_i[:, None] >= j
499
+ # [BC, BK]
500
+ b_dq += tl.where(m_i, b_dA[:, None] * exp(b_kj[None, :] - b_z), 0.)
501
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
502
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
503
+
504
+ tl.debug_barrier()
505
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
506
+ p_zn = tl.make_block_ptr(z + i_bh * T*K, (T*K,), (1,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,))
507
+ # [BK,]
508
+ b_zn = tl.load(p_zn, boundary_check=(0,))
509
+ # [BC, BK]
510
+ b_k = tl.load(p_k, boundary_check=(0, 1))
511
+ b_kz = exp(b_k - b_zn[None, :])
512
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
513
+ for i_j in range(i_i + 1, NC):
514
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
515
+ p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
516
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0))
517
+ # [BC, BK]
518
+ b_q = tl.load(p_q, boundary_check=(0, 1))
519
+ b_z = tl.load(p_z, boundary_check=(0, 1))
520
+ b_qz = (b_q * exp(b_zn[None, :] - b_z)).to(b_q.dtype)
521
+ # [BC, BC]
522
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
523
+ # [BC, BK]
524
+ b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False)
525
+ b_dk *= b_kz
526
+
527
+ o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC)
528
+ for j in range(0, BC):
529
+ p_qj = tl.make_block_ptr(q + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
530
+ p_zj = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,))
531
+ # [BC,]
532
+ b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0)
533
+ # [BK,]
534
+ b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32)
535
+ b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32)
536
+ # [BC, BK]
537
+ m_i = o_i[:, None] <= j
538
+ b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_k - b_zj[None, :]), 0.)
539
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
540
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
541
+
542
+
543
+ @triton.jit(do_not_specialize=['T'])
544
+ def chunk_abc_bwd_kernel_intra_K(
545
+ v,
546
+ z,
547
+ do,
548
+ dA,
549
+ scale,
550
+ T,
551
+ V: tl.constexpr,
552
+ BT: tl.constexpr,
553
+ BC: tl.constexpr,
554
+ BV: tl.constexpr,
555
+ NC: tl.constexpr
556
+ ):
557
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
558
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
559
+ n_bh = tl.num_programs(2)
560
+
561
+ if i_i > i_j:
562
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
563
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
564
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,))
565
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
566
+ p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
567
+ # [BV,]
568
+ b_zn = tl.load(p_zn, boundary_check=(0,))
569
+ # [BC, BV]
570
+ b_z = tl.load(p_z, boundary_check=(0, 1))
571
+ b_do = tl.load(p_do, boundary_check=(0, 1))
572
+ b_do = (b_do * exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype)
573
+ # [BV, BC]
574
+ b_v = tl.load(p_v, boundary_check=(0, 1))
575
+ b_v = exp(b_v - b_zn[:, None]).to(b_v.dtype)
576
+ # [BC, BC]
577
+ b_dA = tl.dot(b_do, b_v, allow_tf32=False)
578
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
579
+ elif i_i == i_j:
580
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,))
581
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
582
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
583
+ # [BC, BV]
584
+ b_z = tl.load(p_z, boundary_check=(0, 1))
585
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
586
+
587
+ o_i = tl.arange(0, BC)
588
+ o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
589
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
590
+ for j in range(0, BC):
591
+ # [BV,]
592
+ b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32)
593
+ # [BC,]
594
+ b_dA = tl.sum(b_do * exp(b_v[None, :] - b_z), 1)
595
+ b_dA = tl.where(o_i >= j, b_dA, 0)
596
+ tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A)
597
+
598
+ p_v = tl.advance(p_v, (V,))
599
+
600
+
601
+ @triton.jit(do_not_specialize=['T'])
602
+ def chunk_abc_bwd_kernel_K(
603
+ q,
604
+ k,
605
+ v,
606
+ z,
607
+ h,
608
+ A,
609
+ do,
610
+ dh,
611
+ dq,
612
+ dk,
613
+ dv,
614
+ dA,
615
+ scale,
616
+ T,
617
+ K: tl.constexpr,
618
+ V: tl.constexpr,
619
+ BT: tl.constexpr,
620
+ BK: tl.constexpr,
621
+ BV: tl.constexpr,
622
+ NT: tl.constexpr
623
+ ):
624
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
625
+ i_p = tl.maximum(i_t * BT - 1, 0)
626
+ n_bh = tl.num_programs(2)
627
+
628
+ o_i = tl.arange(0, BT)
629
+ m_s = o_i[:, None] >= o_i[None, :]
630
+
631
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
632
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
633
+ p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
634
+
635
+ # [BT, BK]
636
+ b_q = tl.load(p_q, boundary_check=(0, 1))
637
+ b_k = tl.load(p_k, boundary_check=(0, 1))
638
+ # [BT, BT]
639
+ b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False)
640
+ b_A = tl.where(m_s, b_A, 0.)
641
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
642
+
643
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
644
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
645
+ for i_v in range(tl.cdiv(V, BV)):
646
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
647
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
648
+ p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,))
649
+ p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))
650
+ p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
651
+
652
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
653
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
654
+ p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
655
+
656
+ # [BV,]
657
+ b_zp = tl.load(p_zp, boundary_check=(0,))
658
+ b_zc = tl.load(p_zc, boundary_check=(0,))
659
+ # [BT, BV]
660
+ b_v = tl.load(p_v, boundary_check=(0, 1))
661
+ b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype)
662
+ b_z = tl.load(p_z, boundary_check=(0, 1))
663
+ b_z = exp(b_zp[None, :] - b_z)
664
+ # [BV, BK]
665
+ b_h = tl.load(p_h, boundary_check=(0, 1))
666
+ # [BT, BV]
667
+ b_do = tl.load(p_do, boundary_check=(0, 1))
668
+ b_do = (b_do * b_z * scale).to(b_do.dtype)
669
+ # [BK, BV]
670
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
671
+
672
+ # [BT, BK]
673
+ b_dq += tl.dot(b_do, b_h, allow_tf32=False)
674
+ b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
675
+ # [BT, BV]
676
+ b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False)
677
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
678
+ p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
679
+ # [BT, BT]
680
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
681
+ # [BT, BK]
682
+ b_dq += tl.dot(b_dA, b_k, allow_tf32=False)
683
+ b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False)
684
+
685
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
686
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
687
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
688
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
689
+
690
+
691
+ @triton.jit(do_not_specialize=['T'])
692
+ def chunk_abc_bwd_kernel_intra_KV(
693
+ v,
694
+ z,
695
+ A,
696
+ do,
697
+ dv,
698
+ T,
699
+ V: tl.constexpr,
700
+ BT: tl.constexpr,
701
+ BC: tl.constexpr,
702
+ BV: tl.constexpr,
703
+ NC: tl.constexpr
704
+ ):
705
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
706
+ i_t, i_i = i_c // NC, i_c % NC
707
+
708
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
709
+ p_zn = tl.make_block_ptr(z + i_bh * T*V, (T*V,), (1,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,))
710
+ # [BV,]
711
+ b_zn = tl.load(p_zn, boundary_check=(0,))
712
+ # [BC, BV]
713
+ b_v = tl.load(p_v, boundary_check=(0, 1))
714
+ b_dv = tl.zeros([BC, BV], dtype=tl.float32)
715
+ for i_j in range(i_i + 1, NC):
716
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
717
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
718
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
719
+ # [BC, BV]
720
+ b_z = tl.load(p_z, boundary_check=(0, 1))
721
+ b_do = tl.load(p_do, boundary_check=(0, 1))
722
+ b_do = (b_do * exp(b_zn[None, :] - b_z)).to(b_do.dtype)
723
+ # [BC, BC]
724
+ b_A = tl.load(p_A, boundary_check=(0, 1))
725
+ b_dv += tl.dot(b_A, b_do, allow_tf32=False)
726
+ b_dv *= exp(b_v - b_zn[None, :])
727
+
728
+ o_i = tl.arange(0, BC)
729
+ for j in range(0, BC):
730
+ p_z = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
731
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,))
732
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,))
733
+ # [BC,]
734
+ b_A = tl.load(p_A, boundary_check=(0,))
735
+ # [BV,]
736
+ b_z = tl.load(p_z, boundary_check=(0,))
737
+ b_do = tl.load(p_do, boundary_check=(0,))
738
+ # [BC, BV]
739
+ m_i = o_i[:, None] <= j
740
+ b_dv += tl.where(m_i, exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.)
741
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
742
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
743
+
744
+
745
+ @triton.jit(do_not_specialize=['T'])
746
+ def chunk_abc_bwd_kernel_rcum_inter(
747
+ s,
748
+ z,
749
+ ss,
750
+ doo,
751
+ T,
752
+ S: tl.constexpr,
753
+ BT: tl.constexpr,
754
+ BS: tl.constexpr,
755
+ NT: tl.constexpr
756
+ ):
757
+ i_m, i_bh = tl.program_id(0), tl.program_id(1)
758
+
759
+ b_sp = tl.zeros([BS,], dtype=tl.float32)
760
+ b_zp = tl.full([BS,], float('inf'), dtype=tl.float32)
761
+ for i_t in range(NT - 1, -1, -1):
762
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
763
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
764
+ p_zc = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,))
765
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
766
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0))
767
+ # [BS,]
768
+ b_zc = tl.load(p_zc, boundary_check=(0,))
769
+ # [BT, BS]
770
+ b_s = tl.load(p_s, boundary_check=(0, 1))
771
+ b_z = tl.load(p_z, boundary_check=(0, 1))
772
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
773
+
774
+ b_doo = exp(b_s - b_zp[None, :]) * b_sp[None, :]
775
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
776
+ # [BS,]
777
+ b_sp = b_sp * exp(b_zc - b_zp) + tl.sum(b_ss * exp(b_zc[None, :] - b_z), 0)
778
+ b_zp = b_zc
779
+
780
+
781
+ @triton.jit(do_not_specialize=['T'])
782
+ def chunk_abc_bwd_kernel_rcum_intra(
783
+ s,
784
+ z,
785
+ ss,
786
+ doo,
787
+ T,
788
+ S: tl.constexpr,
789
+ BT: tl.constexpr,
790
+ BC: tl.constexpr,
791
+ BS: tl.constexpr,
792
+ NC: tl.constexpr
793
+ ):
794
+ i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
795
+ i_t, i_i = i_c // NC, i_c % NC
796
+
797
+ o_i = tl.arange(0, BC)
798
+ m_o = tl.full([BC, BC], 1., dtype=tl.float32)
799
+
800
+ p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
801
+ p_zn = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,))
802
+ p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0))
803
+ # [BC, BS]
804
+ b_s = tl.load(p_s, boundary_check=(0, 1))
805
+ # [BS,]
806
+ b_zn = tl.load(p_zn, boundary_check=(0,))
807
+
808
+ b_doo = tl.zeros([BC, BS], dtype=tl.float32)
809
+ for i_j in range(i_i + 1, NC):
810
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
811
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0))
812
+ # [BC, BS]
813
+ b_z = tl.load(p_z, boundary_check=(0, 1))
814
+ b_ss = tl.load(p_ss, boundary_check=(0, 1))
815
+ # [BC, BS]
816
+ b_doo += b_ss * exp(b_zn[None, :] - b_z)
817
+ b_doo = exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False)
818
+
819
+ for j in range(0, BC):
820
+ p_z = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
821
+ p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,))
822
+ # [BS,]
823
+ b_z = tl.load(p_z, boundary_check=(0,))
824
+ b_ss = tl.load(p_ss, boundary_check=(0,))
825
+ # [BC, BS]
826
+ m_i = o_i[:, None] <= j
827
+ b_doo += tl.where(m_i, exp(b_s - b_z[None, :]) * b_ss[None, :], 0.)
828
+ b_doo += tl.load(p_doo, boundary_check=(0, 1))
829
+ tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1))
830
+
831
+
832
+ class ChunkABCFunction(torch.autograd.Function):
833
+
834
+ @staticmethod
835
+ @input_guard
836
+ def forward(ctx, q, k, v, s, initial_state, output_final_state):
837
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
838
+ BT, BC = 64, 16
839
+ BK = min(64, triton.next_power_of_2(K))
840
+ BV = min(64, triton.next_power_of_2(V))
841
+ BM = min(64, triton.next_power_of_2(M))
842
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
843
+ NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM)
844
+ num_warps = 4 if BK == 64 else 2
845
+ num_stages = 1
846
+
847
+ def fwd_pre(s, B, H, T, S):
848
+ # keep cummulative normalizer in fp32
849
+ z = torch.empty_like(s, dtype=torch.float)
850
+ grid = (B * H,)
851
+ logcumsumexp_fwd_kernel[grid](
852
+ s, z,
853
+ T=T, S=S
854
+ )
855
+ return z
856
+
857
+ def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None):
858
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
859
+ h = q.new_empty(B, H, NT * K, V)
860
+ grid = (NV, NK, B * H)
861
+ chunk_abc_fwd_kernel_h[grid](
862
+ k, v, z, h, h0, ht,
863
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
864
+ NORMK=normk,
865
+ USE_INITIAL_STATE=h0 is not None,
866
+ STORE_FINAL_STATE=ht is not None,
867
+ num_warps=num_warps,
868
+ num_stages=num_stages
869
+ )
870
+ return h
871
+
872
+ final_state = None
873
+ if output_final_state:
874
+ final_state = (q.new_empty(B, H, K, M, dtype=torch.float),
875
+ q.new_empty(B, H, M, V, dtype=torch.float))
876
+
877
+ z = fwd_pre(s, B, H, T, M)
878
+ scale = K ** -0.5
879
+ hk = fwd_inner(
880
+ q=q, k=k, v=s, z=z,
881
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
882
+ normk=False,
883
+ h0=initial_state[0] if initial_state is not None else None,
884
+ ht=final_state[0] if final_state is not None else None
885
+ )
886
+ ok1 = torch.empty_like(s)
887
+ Ak = q.new_empty(B, H, T, BT)
888
+ grid = (NM, NT, B * H)
889
+ chunk_abc_fwd_kernel_K[grid](
890
+ q, k, z, hk, ok1, Ak,
891
+ scale=scale,
892
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
893
+ num_warps=num_warps,
894
+ num_stages=num_stages
895
+ )
896
+ ok0 = torch.empty_like(s)
897
+ grid = (NM, NT * NC, B * H)
898
+ chunk_abc_fwd_kernel_intra_K[grid](
899
+ s, z, ok0, Ak,
900
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
901
+ num_warps=2,
902
+ num_stages=num_stages
903
+ )
904
+ ok = ok0.add_(ok1)
905
+
906
+ scale = 1.
907
+ # p is kept in fp32 for safe softmax backward
908
+ p = softmax_fwd(ok, dtype=torch.float)
909
+ qv = p.to(q.dtype)
910
+
911
+ scale = 1.
912
+ hv = fwd_inner(
913
+ q=qv, k=s, v=v, z=z,
914
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
915
+ normk=True,
916
+ h0=initial_state[1] if initial_state is not None else None,
917
+ ht=final_state[1] if final_state is not None else None
918
+ )
919
+ Av = q.new_zeros(NM, B, H, T, BT)
920
+ grid = (NM, NT * NC * NC, B * H)
921
+ chunk_abc_fwd_kernel_intra_V[grid](
922
+ qv, s, z, Av,
923
+ scale=scale,
924
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
925
+ num_warps=2,
926
+ num_stages=num_stages
927
+ )
928
+ Av = Av.sum(0)
929
+ ov = torch.empty_like(v)
930
+ grid = (NV, NT, B * H)
931
+ chunk_abc_fwd_kernel_V[grid](
932
+ qv, v, z, hv, ov, Av,
933
+ scale=scale,
934
+ T=T,
935
+ K=M,
936
+ V=V,
937
+ BT=BT,
938
+ BK=BM,
939
+ BV=BV,
940
+ NT=NT,
941
+ num_warps=num_warps,
942
+ num_stages=num_stages
943
+ )
944
+ ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av)
945
+ ctx.BT = BT
946
+ return ov, final_state
947
+
948
+ @staticmethod
949
+ @input_guard
950
+ def backward(ctx, dov, dht=None):
951
+ q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors
952
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
953
+ BT, BC = ctx.BT, 16
954
+ BK = min(64, triton.next_power_of_2(K))
955
+ BV = min(64, triton.next_power_of_2(V))
956
+ BM = min(64, triton.next_power_of_2(M))
957
+ NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC)
958
+ NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM)
959
+ num_warps = 4 if BK == 64 else 2
960
+ num_stages = 1
961
+
962
+ def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False):
963
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
964
+ dh = q.new_empty(B, H, NT * K, V)
965
+ grid = (NK, NV, B * H)
966
+ chunk_abc_bwd_kernel_dh[grid](
967
+ q, z, do, dh,
968
+ scale=scale,
969
+ T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
970
+ NORMK=normk,
971
+ num_warps=num_warps,
972
+ num_stages=num_stages
973
+ )
974
+ return dh
975
+
976
+ def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS):
977
+ doo = torch.empty_like(s)
978
+ grid = (NS, B * H)
979
+ chunk_abc_bwd_kernel_rcum_inter[grid](
980
+ s, z, ss, doo,
981
+ T=T, S=S, BT=BT, BS=BS, NT=NT,
982
+ num_warps=num_warps,
983
+ num_stages=num_stages
984
+ )
985
+ grid = (NS, NT * NC, B * H)
986
+ chunk_abc_bwd_kernel_rcum_intra[grid](
987
+ s, z, ss, doo,
988
+ T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC,
989
+ num_warps=num_warps,
990
+ num_stages=num_stages
991
+ )
992
+ return doo
993
+
994
+ scale = 1.
995
+ qv = p.to(q.dtype)
996
+ dhv = bwd_inner(
997
+ qv, z, dov,
998
+ B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
999
+ scale=scale,
1000
+ normk=True
1001
+ )
1002
+ dp1 = torch.empty_like(p)
1003
+ dsv1 = torch.empty_like(s, dtype=torch.float)
1004
+ dv = v.new_empty(NM, *v.shape)
1005
+ dAv = q.new_zeros(B, H, T, BT)
1006
+ grid = (NM, NT, B * H)
1007
+ chunk_abc_bwd_kernel_V[grid](
1008
+ s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv,
1009
+ scale=scale,
1010
+ T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT,
1011
+ num_warps=num_warps,
1012
+ num_stages=num_stages
1013
+ )
1014
+ dv = dv.sum(0)
1015
+ dp0 = torch.empty_like(p)
1016
+ dsv0 = s.new_zeros(s.shape, dtype=torch.float)
1017
+ grid = (NM, NT * NC, B * H)
1018
+ chunk_abc_bwd_kernel_intra_V[grid](
1019
+ qv, s, z, dAv, dp0, dsv0,
1020
+ T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC,
1021
+ num_warps=2,
1022
+ num_stages=num_stages
1023
+ )
1024
+ dp = dp1.add_(dp0)
1025
+ dsv = dsv1.add_(dsv0)
1026
+
1027
+ # softmax gradient, equivalent to:
1028
+ # dok = p * (dp - (p * dp).sum(-1, True))
1029
+ dok = softmax_bwd(p, dp, dtype=ok.dtype)
1030
+
1031
+ scale = K ** -0.5
1032
+ dhk = bwd_inner(
1033
+ q, z, dok,
1034
+ B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1035
+ scale=scale,
1036
+ normk=False
1037
+ )
1038
+ dAk = q.new_zeros(NM, B, H, T, BT)
1039
+ grid = (NM, NT * NC * NC, B * H)
1040
+ chunk_abc_bwd_kernel_intra_K[grid](
1041
+ s, z, dok, dAk,
1042
+ scale=scale,
1043
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1044
+ num_warps=2,
1045
+ num_stages=num_stages
1046
+ )
1047
+ dAk = dAk.sum(0)
1048
+
1049
+ Ak = q.new_zeros(NK, B, H, T, BT)
1050
+ dq = torch.empty_like(q)
1051
+ dk = torch.empty_like(k)
1052
+ dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float)
1053
+ grid = (NK, NT, B * H)
1054
+ chunk_abc_bwd_kernel_K[grid](
1055
+ q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk,
1056
+ scale=scale,
1057
+ T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT,
1058
+ num_warps=num_warps,
1059
+ num_stages=num_stages
1060
+ )
1061
+ Ak = Ak.sum(0)
1062
+ dsk1 = dsk1.sum(0)
1063
+ dsk0 = torch.empty_like(s, dtype=torch.float)
1064
+ grid = (NM, NT * NC, B * H)
1065
+ chunk_abc_bwd_kernel_intra_KV[grid](
1066
+ s, z, Ak, dok, dsk0,
1067
+ T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC,
1068
+ num_warps=2,
1069
+ num_stages=num_stages
1070
+ )
1071
+ ds = dsv.add_(dsk1.add_(dsk0))
1072
+ ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM)
1073
+ ds = ds.to(s.dtype)
1074
+ return dq, dk, dv, ds, None, None
1075
+
1076
+
1077
+ @torch.compiler.disable
1078
+ def chunk_abc(
1079
+ q: torch.Tensor,
1080
+ k: torch.Tensor,
1081
+ v: torch.Tensor,
1082
+ s: torch.Tensor,
1083
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
1084
+ output_final_state: bool = False,
1085
+ head_first: bool = True
1086
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1087
+ r"""
1088
+ Args:
1089
+ q (torch.Tensor):
1090
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1091
+ k (torch.Tensor):
1092
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
1093
+ v (torch.Tensor):
1094
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
1095
+ s (torch.Tensor):
1096
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`
1097
+ initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]):
1098
+ Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`.
1099
+ output_final_state (Optional[bool]):
1100
+ Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`.
1101
+ head_first (Optional[bool]):
1102
+ Whether the inputs are in the head-first format.
1103
+ Default: `True`.
1104
+
1105
+ Returns:
1106
+ o (torch.Tensor):
1107
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1108
+ final_state (torch.Tensor):
1109
+ Final state of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`.
1110
+ """
1111
+ if not head_first:
1112
+ q, k, v, s = map(lambda x: x.transpose(1, 2), (q, k, v, s))
1113
+ o, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state)
1114
+ if not head_first:
1115
+ o = o.transpose(1, 2)
1116
+ return o, final_state
fla/ops/based/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .fused_chunk import fused_chunk_based
4
+ from .parallel import parallel_based
5
+
6
+ __all__ = [
7
+ 'fused_chunk_based',
8
+ 'parallel_based'
9
+ ]
fla/ops/common/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # -*- coding: utf-8 -*-
fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
20
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
22
+ })
23
+ @triton.autotune(
24
+ configs=[
25
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
30
+ use_cuda_graph=use_cuda_graph,
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_gated_delta_rule_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ d,
37
+ v_new,
38
+ g,
39
+ h,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ chunk_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BC: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ NT: tl.constexpr,
53
+ USE_G: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr,
58
+ ):
59
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+ i_n, i_h = i_nh // H, i_nh % H
61
+ if USE_OFFSETS:
62
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
63
+ T = eos - bos
64
+ NT = tl.cdiv(T, BT)
65
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
66
+ else:
67
+ bos, eos = i_n * T, i_n * T + T
68
+ NT = tl.cdiv(T, BT)
69
+ boh = i_n * NT
70
+
71
+ # [BK, BV]
72
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
73
+ if USE_INITIAL_STATE:
74
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
75
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
76
+
77
+ for i_t in range(NT):
78
+ if HEAD_FIRST:
79
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ else:
81
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
83
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
84
+ if USE_G:
85
+ last_idx = min((i_t + 1) * BT, T) - 1
86
+ if HEAD_FIRST:
87
+ b_g_last = tl.load(g + i_nh * T + last_idx)
88
+ else:
89
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
90
+ else:
91
+ b_g_last = None
92
+ last_idx = None
93
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
94
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
95
+ if HEAD_FIRST:
96
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
97
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
98
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
100
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
103
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
104
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
105
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
106
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None
107
+ b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None
108
+ # [BK, BC]
109
+ b_k = tl.load(p_k, boundary_check=(0, 1))
110
+ b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
111
+ # [BC, BK]
112
+ b_d = tl.load(p_d, boundary_check=(0, 1))
113
+ b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
114
+ # [BC, BV]
115
+ b_v = tl.load(p_v, boundary_check=(0, 1))
116
+ b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
117
+ # [BK, BV]
118
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
119
+ b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
120
+ b_h *= exp(b_g_last) if USE_G else 1
121
+ b_h += b_hc
122
+
123
+ if STORE_FINAL_STATE:
124
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
125
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
126
+
127
+
128
+ @triton.heuristics({
129
+ 'USE_G': lambda args: args['g'] is not None,
130
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
131
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in NUM_WARPS
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BT', 'BK', 'BV', 'USE_G'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_gated_delta_rule_bwd_kernel_dhu(
145
+ q,
146
+ k,
147
+ d,
148
+ g,
149
+ dht,
150
+ dh0,
151
+ do,
152
+ dh,
153
+ dv,
154
+ dv2,
155
+ offsets,
156
+ chunk_offsets,
157
+ scale,
158
+ T,
159
+ H: tl.constexpr,
160
+ K: tl.constexpr,
161
+ V: tl.constexpr,
162
+ BT: tl.constexpr,
163
+ BC: tl.constexpr,
164
+ BK: tl.constexpr,
165
+ BV: tl.constexpr,
166
+ USE_G: tl.constexpr,
167
+ USE_INITIAL_STATE: tl.constexpr,
168
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr
171
+ ):
172
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
173
+ i_n, i_h = i_nh // H, i_nh % H
174
+ if USE_OFFSETS:
175
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
176
+ T = eos - bos
177
+ NT = tl.cdiv(T, BT)
178
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
179
+ else:
180
+ bos, eos = i_n * T, i_n * T + T
181
+ NT = tl.cdiv(T, BT)
182
+ boh = i_n * NT
183
+
184
+ # [BK, BV]
185
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
186
+ if USE_FINAL_STATE_GRADIENT:
187
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
188
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
189
+
190
+ for i_t in range(NT - 1, -1, -1):
191
+ if HEAD_FIRST:
192
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
193
+ else:
194
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
195
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
196
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
197
+ if USE_G:
198
+ last_idx = min((i_t + 1) * BT, T) - 1
199
+ if HEAD_FIRST:
200
+ bg_last = tl.load(g + i_nh * T + last_idx)
201
+ else:
202
+ bg_last = tl.load(g + (bos + last_idx) * H + i_h)
203
+ else:
204
+ bg_last = None
205
+ last_idx = None
206
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
207
+ if HEAD_FIRST:
208
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
209
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
210
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
211
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
212
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
213
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
214
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
215
+ else:
216
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
217
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
218
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
219
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
220
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
221
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None
222
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
223
+ b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
224
+ # [BK, BT]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
227
+ # [BT, BK]
228
+ b_k = tl.load(p_k, boundary_check=(0, 1))
229
+ b_d = tl.load(p_d, boundary_check=(0, 1))
230
+ b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
231
+ b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
232
+ # [BT, V]
233
+ b_do = tl.load(p_do, boundary_check=(0, 1))
234
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
235
+ b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
236
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
237
+ # [BK, BV]
238
+ b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
239
+ b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
240
+ b_dh *= exp(bg_last) if USE_G else 1
241
+ b_dh += b_dh_tmp
242
+
243
+ if USE_INITIAL_STATE:
244
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
246
+
247
+
248
+ def chunk_gated_delta_rule_fwd_h(
249
+ k: torch.Tensor,
250
+ w: torch.Tensor,
251
+ u: torch.Tensor,
252
+ g: Optional[torch.Tensor] = None,
253
+ initial_state: Optional[torch.Tensor] = None,
254
+ output_final_state: bool = False,
255
+ offsets: Optional[torch.LongTensor] = None,
256
+ indices: Optional[torch.LongTensor] = None,
257
+ head_first: bool = True,
258
+ chunk_size: int = 64
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ if head_first:
261
+ B, H, T, K, V = *k.shape, u.shape[-1]
262
+ else:
263
+ B, T, H, K, V = *k.shape, u.shape[-1]
264
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
265
+ # N: the actual number of sequences in the batch with either equal or variable lengths
266
+ if offsets is None:
267
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
268
+ else:
269
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
270
+ BK = triton.next_power_of_2(K)
271
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
272
+ # H100 can have larger block size
273
+ if check_shared_mem('hopper', k.device.index):
274
+ BV = 64
275
+ BC = 64 if K <= 128 else 32
276
+ # A100
277
+ elif check_shared_mem('ampere', k.device.index):
278
+ BV = 32
279
+ BC = 64
280
+ else:
281
+ BV = 32
282
+ BC = 32 if K <= 128 else 16
283
+ BC = min(BT, BC)
284
+ NK = triton.cdiv(K, BK)
285
+ NV = triton.cdiv(V, BV)
286
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
287
+
288
+ if head_first:
289
+ h = k.new_empty(B, H, NT, K, V)
290
+ else:
291
+ h = k.new_empty(B, NT, H, K, V)
292
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
293
+
294
+ v_new = torch.empty_like(u)
295
+ grid = (NK, NV, N * H)
296
+
297
+ chunk_gated_delta_rule_fwd_kernel_h[grid](
298
+ k=k,
299
+ v=u,
300
+ d=w,
301
+ v_new=v_new,
302
+ g=g,
303
+ h=h,
304
+ h0=initial_state,
305
+ ht=final_state,
306
+ offsets=offsets,
307
+ chunk_offsets=chunk_offsets,
308
+ T=T,
309
+ H=H,
310
+ K=K,
311
+ V=V,
312
+ BT=BT,
313
+ BC=BC,
314
+ BK=BK,
315
+ BV=BV,
316
+ NT=NT,
317
+ HEAD_FIRST=head_first
318
+ )
319
+ return h, v_new, final_state
320
+
321
+
322
+ def chunk_gated_delta_rule_bwd_dhu(
323
+ q: torch.Tensor,
324
+ k: torch.Tensor,
325
+ w: torch.Tensor,
326
+ g: torch.Tensor,
327
+ h0: torch.Tensor,
328
+ dht: Optional[torch.Tensor],
329
+ do: torch.Tensor,
330
+ dv: torch.Tensor,
331
+ scale: float,
332
+ offsets: Optional[torch.LongTensor] = None,
333
+ indices: Optional[torch.LongTensor] = None,
334
+ head_first: bool = True,
335
+ chunk_size: int = 64
336
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
337
+ if head_first:
338
+ B, H, T, K, V = *q.shape, do.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *q.shape, do.shape[-1]
341
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
342
+ # N: the actual number of sequences in the batch with either equal or variable lengths
343
+ if offsets is None:
344
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
345
+ else:
346
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
347
+
348
+ BK = triton.next_power_of_2(K)
349
+ assert BK <= 256, "current kernel does not support head dimension being larger than 256."
350
+
351
+ # H100
352
+ if check_shared_mem('hopper', q.device.index):
353
+ BV = 64
354
+ BC = 64 if K <= 128 else 32
355
+ # A100
356
+ elif check_shared_mem('ampere', q.device.index):
357
+ BV = 32
358
+ BC = 64 if K <= 128 else 32
359
+ else:
360
+ BV = 32 if K <= 128 else 16
361
+ BC = 16
362
+
363
+ BC = min(BT, BC)
364
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
365
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
366
+
367
+ if head_first:
368
+ dh = q.new_empty(B, H, NT, K, V)
369
+ else:
370
+ dh = q.new_empty(B, NT, H, K, V)
371
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
372
+ dv2 = torch.empty_like(dv)
373
+
374
+ grid = (NK, NV, N * H)
375
+ chunk_gated_delta_rule_bwd_kernel_dhu[grid](
376
+ q=q,
377
+ k=k,
378
+ d=w,
379
+ g=g,
380
+ dht=dht,
381
+ dh0=dh0,
382
+ do=do,
383
+ dh=dh,
384
+ dv=dv,
385
+ dv2=dv2,
386
+ offsets=offsets,
387
+ chunk_offsets=chunk_offsets,
388
+ scale=scale,
389
+ T=T,
390
+ H=H,
391
+ K=K,
392
+ V=V,
393
+ BT=BT,
394
+ BC=BC,
395
+ BK=BK,
396
+ BV=BV,
397
+ HEAD_FIRST=head_first
398
+ )
399
+ return dh, dh0, dv2
fla/ops/common/chunk_h.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem
13
+
14
+ BKV_LIST = [32, 64] if check_shared_mem() else [16, 32]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in BKV_LIST
26
+ for BV in BKV_LIST
27
+ for num_warps in [1, 2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ split_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BS: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ USE_G: tl.constexpr,
53
+ USE_GK: tl.constexpr,
54
+ USE_GV: tl.constexpr,
55
+ USE_INITIAL_STATE: tl.constexpr,
56
+ STORE_FINAL_STATE: tl.constexpr,
57
+ USE_OFFSETS: tl.constexpr,
58
+ HEAD_FIRST: tl.constexpr
59
+ ):
60
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
61
+ i_n, i_h = i_nh // H, i_nh % H
62
+ if USE_OFFSETS:
63
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
64
+ T = eos - bos
65
+ NT = tl.cdiv(T, BT)
66
+ NS = tl.cdiv(T, BS)
67
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
68
+ else:
69
+ bos, eos = i_n * T, i_n * T + T
70
+ NT = tl.cdiv(T, BT)
71
+ NS = tl.cdiv(T, BS)
72
+ boh = i_n * NS
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ if USE_INITIAL_STATE:
77
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
79
+
80
+ for i_t in range(NT):
81
+ i_s = i_t // (BS // BT)
82
+ if HEAD_FIRST:
83
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
84
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
85
+
86
+ o_h = (i_nh * NS + i_s).to(tl.int64) * K*V
87
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
88
+ else:
89
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
90
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
91
+
92
+ o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
93
+ p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
94
+
95
+ if i_t % (BS // BT) == 0:
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+ # [BK, BT]
98
+ b_k = tl.load(p_k, boundary_check=(0, 1))
99
+ # [BT, BV]
100
+ b_v = tl.load(p_v, boundary_check=(0, 1))
101
+ last_idx = min((i_t + 1) * BT, T) - 1
102
+
103
+ # scalar decay
104
+ if USE_G:
105
+ if HEAD_FIRST:
106
+ b_g_last = tl.load(g + i_nh * T + last_idx)
107
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
108
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
109
+ else:
110
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
111
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
112
+ b_h *= exp(b_g_last)
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+ b_h *= exp(b_gk_last)[:, None]
128
+
129
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
130
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
131
+
132
+ # vector decay, h = h @ Diag(gv)
133
+ if USE_GV:
134
+ if HEAD_FIRST:
135
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
136
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
137
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
138
+ else:
139
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
140
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
141
+
142
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
143
+ b_h *= exp(b_gv_last)[None, :]
144
+
145
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
146
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
147
+
148
+ b_h += tl.dot(b_k, b_v)
149
+
150
+ if STORE_FINAL_STATE:
151
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
153
+
154
+
155
+ @triton.heuristics({
156
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
157
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
158
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
159
+ })
160
+ @triton.autotune(
161
+ configs=[
162
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
163
+ for BK in BKV_LIST
164
+ for BV in BKV_LIST
165
+ for num_warps in [1, 2, 4, 8]
166
+ for num_stages in [2, 3, 4]
167
+ ],
168
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
169
+ )
170
+ @triton.jit(do_not_specialize=['T'])
171
+ def chunk_bwd_kernel_dh(
172
+ q,
173
+ g,
174
+ gk,
175
+ gv,
176
+ do,
177
+ dh,
178
+ dht,
179
+ dh0,
180
+ offsets,
181
+ split_offsets,
182
+ scale,
183
+ T,
184
+ HQ: tl.constexpr,
185
+ H: tl.constexpr,
186
+ K: tl.constexpr,
187
+ V: tl.constexpr,
188
+ BT: tl.constexpr,
189
+ BS: tl.constexpr,
190
+ BK: tl.constexpr,
191
+ BV: tl.constexpr,
192
+ NG: tl.constexpr,
193
+ USE_G: tl.constexpr,
194
+ USE_GK: tl.constexpr,
195
+ USE_GV: tl.constexpr,
196
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
197
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
202
+ i_bg = i_nh // NG
203
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
204
+ i_h = i_hq // NG
205
+ if USE_OFFSETS:
206
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
207
+ T = eos - bos
208
+ NT = tl.cdiv(T, BT)
209
+ NS = tl.cdiv(T, BS)
210
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
211
+ else:
212
+ bos, eos = i_n * T, i_n * T + T
213
+ NT = tl.cdiv(T, BT)
214
+ NS = tl.cdiv(T, BS)
215
+ boh = i_n * NS
216
+
217
+ # [BK, BV]
218
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
219
+ if USE_FINAL_STATE_GRADIENT:
220
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
221
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
222
+
223
+ for i_t in range(NT - 1, -1, -1):
224
+ i_s = i_t // (BS // BT)
225
+ if HEAD_FIRST:
226
+ o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V
227
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
228
+ else:
229
+ o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
230
+ p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
231
+
232
+ if i_t % (BS // BT) == 0:
233
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
234
+ last_idx = min(i_t * BT + BT, T) - 1
235
+ # [BK, BT]
236
+ if HEAD_FIRST:
237
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
238
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ else:
240
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
241
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ b_q = tl.load(p_q, boundary_check=(0, 1))
243
+ b_q = (b_q * scale).to(b_q.dtype)
244
+ # [BT, BV]
245
+ b_do = tl.load(p_do, boundary_check=(0, 1))
246
+
247
+ if USE_G:
248
+ if HEAD_FIRST:
249
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
250
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
251
+ b_g_last = tl.load(g + i_bg * T + last_idx)
252
+ else:
253
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
254
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
255
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
256
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
257
+
258
+ b_dh *= exp(b_g_last)
259
+
260
+ if USE_GK:
261
+ if HEAD_FIRST:
262
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
263
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
264
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
265
+ else:
266
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
267
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+
269
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
270
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
271
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
272
+ b_dh *= exp(b_gk_last)[:, None]
273
+
274
+ if USE_GV:
275
+ if HEAD_FIRST:
276
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
277
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
278
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
279
+ else:
280
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
281
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
282
+
283
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
284
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
285
+
286
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
287
+ b_dh *= exp(b_gv_last)[None, :]
288
+
289
+ b_dh += tl.dot(b_q, b_do)
290
+
291
+ if STORE_INITIAL_STATE_GRADIENT:
292
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
293
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
294
+
295
+
296
+ def chunk_fwd_h(
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ g: torch.Tensor,
300
+ gk: torch.Tensor,
301
+ gv: torch.Tensor,
302
+ h0: torch.Tensor,
303
+ output_final_state: bool,
304
+ offsets: Optional[torch.Tensor] = None,
305
+ head_first: bool = True,
306
+ chunk_size: int = 64,
307
+ split_size: Optional[int] = None,
308
+ states_in_fp32: bool = False
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ if head_first:
311
+ B, H, T, K, V = *k.shape, v.shape[-1]
312
+ else:
313
+ B, T, H, K, V = *k.shape, v.shape[-1]
314
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
315
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
316
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
317
+ # N: the actual number of sequences in the batch with either equal or variable lengths
318
+ if offsets is None:
319
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
320
+ else:
321
+ split_offsets = prepare_chunk_offsets(offsets, BS)
322
+ N, NS = len(offsets) - 1, split_offsets[-1]
323
+
324
+ if head_first:
325
+ h = k.new_empty(B, H, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
326
+ else:
327
+ h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
328
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
329
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
330
+ chunk_fwd_kernel_h[grid](
331
+ k=k,
332
+ v=v,
333
+ h=h,
334
+ g=g,
335
+ gk=gk,
336
+ gv=gv,
337
+ h0=h0,
338
+ ht=ht,
339
+ offsets=offsets,
340
+ split_offsets=split_offsets,
341
+ T=T,
342
+ H=H,
343
+ K=K,
344
+ V=V,
345
+ BT=BT,
346
+ BS=BS,
347
+ USE_G=g is not None,
348
+ USE_GK=gk is not None,
349
+ USE_GV=gv is not None,
350
+ HEAD_FIRST=head_first
351
+ )
352
+ return h, ht
353
+
354
+
355
+ def chunk_bwd_dh(
356
+ q: torch.Tensor,
357
+ k: torch.Tensor,
358
+ v: torch.Tensor,
359
+ g: torch.Tensor,
360
+ gk: torch.Tensor,
361
+ gv: torch.Tensor,
362
+ do: torch.Tensor,
363
+ h0: torch.Tensor,
364
+ dht: torch.Tensor,
365
+ scale: float,
366
+ offsets: Optional[torch.Tensor] = None,
367
+ head_first: bool = True,
368
+ chunk_size: int = 64,
369
+ split_size: Optional[int] = None,
370
+ states_in_fp32: bool = False
371
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
372
+ if head_first:
373
+ B, H, T, K, V = *k.shape, v.shape[-1]
374
+ HQ = q.shape[1]
375
+ else:
376
+ B, T, H, K, V = *k.shape, v.shape[-1]
377
+ HQ = q.shape[2]
378
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
379
+ BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
380
+ assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
381
+ # N: the actual number of sequences in the batch with either equal or variable lengths
382
+ # NG: number of groups in GQA
383
+ if offsets is None:
384
+ split_offsets, N, NS = None, B, triton.cdiv(T, BS)
385
+ else:
386
+ split_offsets = prepare_chunk_offsets(offsets, BS)
387
+ N, NS = len(offsets) - 1, split_offsets[-1]
388
+ NG = HQ // H
389
+
390
+ if head_first:
391
+ dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
392
+ else:
393
+ dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
394
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
395
+
396
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
397
+ chunk_bwd_kernel_dh[grid](
398
+ q=q,
399
+ g=g,
400
+ gk=gk,
401
+ gv=gv,
402
+ do=do,
403
+ dh=dh,
404
+ dht=dht,
405
+ dh0=dh0,
406
+ offsets=offsets,
407
+ split_offsets=split_offsets,
408
+ scale=scale,
409
+ T=T,
410
+ HQ=HQ,
411
+ H=H,
412
+ K=K,
413
+ V=V,
414
+ BT=BT,
415
+ BS=BS,
416
+ NG=NG,
417
+ USE_G=g is not None,
418
+ USE_GK=gk is not None,
419
+ USE_GV=gv is not None,
420
+ HEAD_FIRST=head_first
421
+ )
422
+ return dh, dh0
fla/ops/common/chunk_o.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, safe_exp
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in BKV_LIST
25
+ for BV in BKV_LIST
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT'],
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_fwd_kernel_o(
33
+ q,
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ o,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr,
51
+ HEAD_FIRST: tl.constexpr
52
+ ):
53
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
54
+ i_b, i_h = i_bh // H, i_bh % H
55
+
56
+ if USE_OFFSETS:
57
+ i_tg = i_t
58
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ else:
63
+ NT = tl.cdiv(T, BT)
64
+ i_tg = i_b * NT + i_t
65
+ bos, eos = i_b * T, i_b * T + T
66
+
67
+ s_qk = K if HEAD_FIRST else H*K
68
+ s_vo = V if HEAD_FIRST else H*V
69
+ s_g = 1 if HEAD_FIRST else H
70
+ # offset calculation
71
+ q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
72
+ k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K)
73
+ v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
74
+ o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V)
75
+ h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * K*V)
76
+
77
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
78
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
79
+
80
+ for i_k in range(tl.cdiv(K, BK)):
81
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
82
+ p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ # [BK, BT]
87
+ b_k = tl.load(p_k, boundary_check=(0, 1))
88
+ # [BK, BV]
89
+ b_h = tl.load(p_h, boundary_check=(0, 1))
90
+
91
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+
96
+ if USE_G:
97
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
98
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
99
+ b_g = tl.load(p_g, boundary_check=(0,))
100
+ b_o = b_o * exp(b_g)[:, None]
101
+ b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
102
+
103
+ o_i = tl.arange(0, BT)
104
+ m_A = o_i[:, None] >= o_i[None, :]
105
+ b_A = tl.where(m_A, b_A, 0)
106
+
107
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
108
+ p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
109
+ b_v = tl.load(p_v, boundary_check=(0, 1))
110
+
111
+ # to fix mma -> mma layout conversion
112
+ # already solved by triton v3.2 or higher
113
+ b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
114
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
115
+
116
+
117
+ @triton.heuristics({
118
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
119
+ 'USE_G': lambda args: args['g'] is not None,
120
+ 'USE_DW': lambda args: args['dw'] is not None
121
+ })
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
125
+ for num_warps in NUM_WARPS
126
+ for num_stages in [2, 3, 4]
127
+ ],
128
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'],
129
+ )
130
+ @triton.jit(do_not_specialize=['T'])
131
+ def chunk_bwd_kernel_dqkwg(
132
+ q,
133
+ k,
134
+ v,
135
+ h,
136
+ g,
137
+ do,
138
+ dh,
139
+ dq,
140
+ dk,
141
+ dg,
142
+ w,
143
+ dv,
144
+ dw,
145
+ offsets,
146
+ indices,
147
+ scale,
148
+ B: tl.constexpr,
149
+ T,
150
+ H: tl.constexpr,
151
+ K: tl.constexpr,
152
+ V: tl.constexpr,
153
+ BT: tl.constexpr,
154
+ BK: tl.constexpr,
155
+ BV: tl.constexpr,
156
+ USE_G: tl.constexpr,
157
+ USE_DW: tl.constexpr,
158
+ USE_OFFSETS: tl.constexpr,
159
+ HEAD_FIRST: tl.constexpr
160
+ ):
161
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
162
+ i_b, i_h = i_bh // H, i_bh % H
163
+ if USE_G:
164
+ dg += i_k * B * H * T
165
+ if USE_OFFSETS:
166
+ i_tg = i_t
167
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
168
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
169
+ T = eos - bos
170
+ NT = tl.cdiv(T, BT)
171
+ else:
172
+ NT = tl.cdiv(T, BT)
173
+ i_tg = i_b * NT + i_t
174
+ bos, eos = i_b * T, i_b * T + T
175
+
176
+ # offset calculation
177
+ v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
178
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
179
+ h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
180
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
181
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
182
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
183
+ dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
184
+ dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
185
+ s_qk = K if HEAD_FIRST else H*K
186
+ s_vo = V if HEAD_FIRST else H*V
187
+ s_g = 1 if HEAD_FIRST else H
188
+
189
+ # for delta rule only
190
+ if USE_DW:
191
+ dw += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
192
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
193
+ w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
194
+
195
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
196
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
197
+ b_ds = tl.zeros([BT, BT], dtype=tl.float32)
198
+ b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None
199
+ b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None
200
+
201
+ for i_v in range(tl.cdiv(V, BV)):
202
+ p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
203
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
204
+ p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
205
+ p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
206
+ # [BT, BV]
207
+ b_v = tl.load(p_v, boundary_check=(0, 1))
208
+ b_do = tl.load(p_do, boundary_check=(0, 1))
209
+ # [BV, BK]
210
+ b_h = tl.load(p_h, boundary_check=(0, 1))
211
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
212
+ if USE_G:
213
+ b_dg_last += (tl.sum(b_h * b_dh))
214
+ # [BT, BV] @ [BV, BT] -> [BT, BT]
215
+ b_ds += tl.dot(b_do, tl.trans(b_v))
216
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
217
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
218
+ # [BT, BV] @ [BV, BK] -> [BT, BK]
219
+ b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
220
+ if USE_DW:
221
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
222
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
223
+ b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
224
+
225
+ if USE_DW and not USE_G:
226
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+ tl.debug_barrier()
230
+ o_i = tl.arange(0, BT)
231
+ p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
232
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
233
+ b_q = tl.load(p_q, boundary_check=(0, 1))
234
+ b_k = tl.load(p_k, boundary_check=(0, 1))
235
+
236
+ p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
237
+ p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
238
+
239
+ if USE_G:
240
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
241
+ g += i_bh * T if HEAD_FIRST else bos * H + i_h
242
+ dg += i_bh * T if HEAD_FIRST else bos * H + i_h
243
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
244
+ b_g = tl.load(p_g, boundary_check=(0,))
245
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
246
+ b_dg_last *= exp(b_g_last)
247
+
248
+ if USE_DW:
249
+ p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
250
+ p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
251
+ b_w = tl.load(p_w, boundary_check=(0, 1))
252
+ b_dw = b_dw * exp(b_g)[:, None]
253
+ tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
254
+ b_dg -= tl.sum(b_w * b_dw, axis=1)
255
+
256
+ b_dq = b_dq * exp(b_g)[:, None] * scale
257
+ b_dg += tl.sum(b_dq * b_q, axis=1)
258
+
259
+ b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None]
260
+ b_dg -= tl.sum(b_k * b_dk, axis=1)
261
+ b_dg_last += tl.sum(b_dk * b_k)
262
+
263
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale
264
+ b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
265
+ b_dg += tl.sum(b_ds2, axis=1)
266
+ b_dg -= tl.sum(b_ds2, axis=0)
267
+
268
+ b_ds = b_ds.to(b_k.dtype)
269
+ # [BT, BK]
270
+ b_dq += tl.dot(b_ds, b_k)
271
+ b_dk += tl.dot(tl.trans(b_ds), b_q)
272
+ p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
273
+ # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue
274
+ # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last)
275
+ b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last)
276
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
277
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
278
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
279
+ else:
280
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0)
281
+ b_ds = b_ds.to(b_k.dtype)
282
+ b_dq += tl.dot(b_ds, b_k)
283
+ b_dk += tl.dot(tl.trans(b_ds), b_q) * scale
284
+ b_dq *= scale
285
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+
288
+
289
+ @triton.heuristics({
290
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
291
+ 'USE_G': lambda args: args['g'] is not None,
292
+ })
293
+ @triton.autotune(
294
+ configs=[
295
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
296
+ for num_warps in [2, 4, 8]
297
+ for num_stages in [2, 3, 4]
298
+ ],
299
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
300
+ )
301
+ @triton.jit(do_not_specialize=['T'])
302
+ def chunk_bwd_kernel_dv(
303
+ q,
304
+ k,
305
+ g,
306
+ do,
307
+ dv,
308
+ dh,
309
+ offsets,
310
+ indices,
311
+ scale,
312
+ T,
313
+ H: tl.constexpr,
314
+ K: tl.constexpr,
315
+ V: tl.constexpr,
316
+ BT: tl.constexpr,
317
+ BK: tl.constexpr,
318
+ BV: tl.constexpr,
319
+ USE_G: tl.constexpr,
320
+ USE_OFFSETS: tl.constexpr,
321
+ HEAD_FIRST: tl.constexpr
322
+ ):
323
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
324
+ i_b, i_h = i_bh // H, i_bh % H
325
+ if USE_OFFSETS:
326
+ i_tg = i_t
327
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
328
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
329
+ T = eos - bos
330
+ NT = tl.cdiv(T, BT)
331
+ else:
332
+ NT = tl.cdiv(T, BT)
333
+ i_tg = i_b * NT + i_t
334
+ bos, eos = i_b * T, i_b * T + T
335
+
336
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
337
+
338
+ # offset calculation
339
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
340
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
341
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
342
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
343
+ s_qk = K if HEAD_FIRST else H*K
344
+ s_vo = V if HEAD_FIRST else H*V
345
+ s_g = 1 if HEAD_FIRST else H
346
+ dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V
347
+
348
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
349
+ for i_k in range(tl.cdiv(K, BK)):
350
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
351
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
352
+ b_q = tl.load(p_q, boundary_check=(0, 1))
353
+ b_k = tl.load(p_k, boundary_check=(0, 1))
354
+ b_A += tl.dot(b_k, b_q)
355
+ p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
356
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
357
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype))
358
+
359
+ if USE_G:
360
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
361
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
362
+ b_g = tl.load(p_g, boundary_check=(0,))
363
+ b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g)
364
+ b_dv *= safe_exp(-b_g + b_g_last)[:, None]
365
+
366
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
367
+ if USE_G:
368
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
369
+ else:
370
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
371
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
372
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ b_do = tl.load(p_do, boundary_check=(0, 1))
374
+ b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
375
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
376
+
377
+
378
+ @triton.heuristics({
379
+ 'USE_G': lambda args: args['g'] is not None,
380
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
381
+ })
382
+ @triton.autotune(
383
+ configs=[
384
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
385
+ for num_warps in NUM_WARPS
386
+ for num_stages in [2, 3, 4]
387
+ ],
388
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
389
+ )
390
+ @triton.jit(do_not_specialize=['T'])
391
+ def chunk_bwd_kernel_dv_local(
392
+ q,
393
+ k,
394
+ g,
395
+ do,
396
+ dv,
397
+ offsets,
398
+ indices,
399
+ scale,
400
+ T,
401
+ H: tl.constexpr,
402
+ K: tl.constexpr,
403
+ V: tl.constexpr,
404
+ BT: tl.constexpr,
405
+ BK: tl.constexpr,
406
+ BV: tl.constexpr,
407
+ USE_G: tl.constexpr,
408
+ USE_OFFSETS: tl.constexpr,
409
+ HEAD_FIRST: tl.constexpr
410
+ ):
411
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
412
+ i_b, i_h = i_bh // H, i_bh % H
413
+ if USE_OFFSETS:
414
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
415
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
416
+ T = eos - bos
417
+ else:
418
+ bos, eos = i_b * T, i_b * T + T
419
+
420
+ # offset calculation
421
+ q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
422
+ k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K
423
+ do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
424
+ dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V
425
+ s_qk = K if HEAD_FIRST else H*K
426
+ s_vo = V if HEAD_FIRST else H*V
427
+ s_g = 1 if HEAD_FIRST else H
428
+
429
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
430
+ for i_k in range(tl.cdiv(K, BK)):
431
+ p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
432
+ p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
433
+ b_q = tl.load(p_q, boundary_check=(0, 1))
434
+ b_k = tl.load(p_k, boundary_check=(0, 1))
435
+ b_A += tl.dot(b_k, b_q)
436
+
437
+ if USE_G:
438
+ g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h)
439
+ p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,))
440
+ b_g = tl.load(p_g, boundary_check=(0,))
441
+
442
+ mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :])
443
+ if USE_G:
444
+ b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
445
+ else:
446
+ b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty)
447
+
448
+ for i_v in range(tl.cdiv(V, BV)):
449
+ p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
450
+ p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
451
+ b_do = tl.load(p_do, boundary_check=(0, 1))
452
+ b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
453
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
454
+
455
+
456
+ def chunk_fwd_o(
457
+ q: torch.Tensor,
458
+ k: torch.Tensor,
459
+ v: torch.Tensor,
460
+ h: torch.Tensor,
461
+ g: Optional[torch.Tensor] = None, # cumsum of log decay
462
+ scale: Optional[float] = None,
463
+ offsets: Optional[torch.LongTensor] = None,
464
+ indices: Optional[torch.LongTensor] = None,
465
+ head_first: bool = True,
466
+ chunk_size: int = 64
467
+ ) -> torch.Tensor:
468
+ if head_first:
469
+ B, H, T, K, V = *q.shape, v.shape[-1]
470
+ else:
471
+ B, T, H, K, V = *q.shape, v.shape[-1]
472
+ if scale is None:
473
+ scale = k.shape[-1] ** -0.5
474
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
475
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
476
+
477
+ o = torch.empty_like(v)
478
+
479
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
480
+ chunk_fwd_kernel_o[grid](
481
+ q,
482
+ k,
483
+ v,
484
+ h,
485
+ g,
486
+ o,
487
+ offsets,
488
+ indices,
489
+ scale,
490
+ T=T,
491
+ H=H,
492
+ K=K,
493
+ V=V,
494
+ BT=BT,
495
+ HEAD_FIRST=head_first
496
+ )
497
+ return o
498
+
499
+
500
+ def chunk_bwd_dv(
501
+ q: torch.Tensor,
502
+ k: torch.Tensor,
503
+ g: torch.Tensor,
504
+ do: torch.Tensor,
505
+ dh: torch.Tensor,
506
+ scale: float,
507
+ offsets: Optional[torch.LongTensor] = None,
508
+ indices: Optional[torch.LongTensor] = None,
509
+ head_first: bool = True,
510
+ chunk_size: int = 64
511
+ ) -> torch.Tensor:
512
+ if head_first:
513
+ B, H, T, K, V = *k.shape, do.shape[-1]
514
+ else:
515
+ B, T, H, K, V = *k.shape, do.shape[-1]
516
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
517
+ # H100 can have larger block size
518
+ if check_shared_mem('hopper', k.device.index):
519
+ CONST_TILING = 128
520
+ elif check_shared_mem:
521
+ CONST_TILING = 64
522
+ else:
523
+ CONST_TILING = 32
524
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
525
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
526
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
527
+ NV = triton.cdiv(V, BV)
528
+
529
+ dv = torch.empty_like(do)
530
+ grid = (NV, NT, B * H)
531
+ chunk_bwd_kernel_dv[grid](
532
+ q,
533
+ k,
534
+ g,
535
+ do,
536
+ dv,
537
+ dh,
538
+ offsets,
539
+ indices,
540
+ scale,
541
+ T=T,
542
+ H=H,
543
+ K=K,
544
+ V=V,
545
+ BT=BT,
546
+ BK=BK,
547
+ BV=BV,
548
+ HEAD_FIRST=head_first
549
+ )
550
+ return dv
551
+
552
+
553
+ def chunk_bwd_dv_local(
554
+ q: torch.Tensor,
555
+ k: torch.Tensor,
556
+ g: torch.Tensor,
557
+ do: torch.Tensor,
558
+ dh: torch.Tensor,
559
+ scale: float,
560
+ offsets: Optional[torch.LongTensor] = None,
561
+ indices: Optional[torch.LongTensor] = None,
562
+ head_first: bool = True,
563
+ chunk_size: int = 64
564
+ ) -> torch.Tensor:
565
+ if head_first:
566
+ B, H, T, K, V = *k.shape, do.shape[-1]
567
+ else:
568
+ B, T, H, K, V = *k.shape, do.shape[-1]
569
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
570
+ # H100 can have larger block size
571
+ if check_shared_mem('hopper', k.device.index):
572
+ CONST_TILING = 128
573
+ elif check_shared_mem:
574
+ CONST_TILING = 64
575
+ else:
576
+ CONST_TILING = 32
577
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
578
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
579
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
580
+
581
+ dv = torch.empty_like(do)
582
+ grid = (NT, B * H)
583
+ chunk_bwd_kernel_dv_local[grid](
584
+ q,
585
+ k,
586
+ g,
587
+ do,
588
+ dv,
589
+ offsets,
590
+ indices,
591
+ scale,
592
+ T=T,
593
+ H=H,
594
+ K=K,
595
+ V=V,
596
+ BT=BT,
597
+ BK=BK,
598
+ BV=BV,
599
+ HEAD_FIRST=head_first
600
+ )
601
+ return dv
602
+
603
+
604
+ def chunk_bwd_dqkwg(
605
+ q: torch.Tensor,
606
+ k: torch.Tensor,
607
+ v: torch.Tensor,
608
+ g: torch.Tensor,
609
+ do: torch.Tensor,
610
+ h: torch.Tensor,
611
+ dh: torch.Tensor,
612
+ dv: Optional[torch.Tensor] = None,
613
+ w: Optional[torch.Tensor] = None,
614
+ offsets: Optional[torch.LongTensor] = None,
615
+ indices: Optional[torch.LongTensor] = None,
616
+ chunk_size: int = 64,
617
+ scale: float = 1.0,
618
+ head_first: bool = True,
619
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
620
+
621
+ if head_first:
622
+ B, H, T, K, V = *k.shape, v.shape[-1]
623
+ else:
624
+ B, T, H, K, V = *k.shape, v.shape[-1]
625
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
626
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
627
+
628
+ CONST_TILING = 64 if check_shared_mem() else 32
629
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
630
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
631
+ NK = triton.cdiv(K, BK)
632
+ dq = torch.empty_like(q)
633
+ dk = torch.empty_like(k)
634
+ dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None
635
+ dw = torch.empty_like(w) if w is not None else None
636
+
637
+ grid = (NK, NT, B * H)
638
+ chunk_bwd_kernel_dqkwg[grid](
639
+ q=q,
640
+ k=k,
641
+ v=v,
642
+ h=h,
643
+ g=g,
644
+ do=do,
645
+ dh=dh,
646
+ dv=dv,
647
+ w=w,
648
+ dw=dw,
649
+ dq=dq,
650
+ dk=dk,
651
+ dg=dg,
652
+ offsets=offsets,
653
+ indices=indices,
654
+ scale=scale,
655
+ B=B,
656
+ T=T,
657
+ H=H,
658
+ K=K,
659
+ V=V,
660
+ BT=BT,
661
+ BK=BK,
662
+ BV=BV,
663
+ HEAD_FIRST=head_first
664
+ )
665
+
666
+ if dg is not None:
667
+ dg = dg.sum(0)
668
+ return dq, dk, dw, dg
fla/ops/common/chunk_scaled_dot_kkt.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_indices
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
15
+ })
16
+ @triton.autotune(
17
+ configs=[
18
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
19
+ for BK in [32, 64, 128]
20
+ for num_warps in [2, 4, 8]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['H', 'K', 'BT', 'USE_OFFSETS'],
24
+ )
25
+ @triton.jit(do_not_specialize=['T'])
26
+ def chunk_scaled_dot_kkt_fwd_kernel(
27
+ k,
28
+ beta,
29
+ A,
30
+ offsets,
31
+ indices,
32
+ T,
33
+ H: tl.constexpr,
34
+ K: tl.constexpr,
35
+ BT: tl.constexpr,
36
+ BK: tl.constexpr,
37
+ HEAD_FIRST: tl.constexpr,
38
+ USE_OFFSETS: tl.constexpr,
39
+ ):
40
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
41
+ i_b, i_h = i_bh // H, i_bh % H
42
+ if USE_OFFSETS:
43
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_b * T, i_b * T + T
48
+ o_t = tl.arange(0, BT)
49
+
50
+ if HEAD_FIRST:
51
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
52
+ else:
53
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
54
+ b_beta = tl.load(p_beta, boundary_check=(0,))
55
+
56
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
57
+ for i_k in range(tl.cdiv(K, BK)):
58
+ if HEAD_FIRST:
59
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
60
+ else:
61
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
62
+ b_k = tl.load(p_k, boundary_check=(0, 1))
63
+ b_kb = b_k * b_beta[:, None]
64
+ b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
65
+
66
+ b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
67
+ if HEAD_FIRST:
68
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
69
+ else:
70
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
71
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
72
+
73
+
74
+ def chunk_scaled_dot_kkt_fwd(
75
+ k: torch.Tensor,
76
+ beta: torch.Tensor,
77
+ cu_seqlens: Optional[torch.LongTensor],
78
+ head_first: bool = False,
79
+ chunk_size: int = 64,
80
+ output_dtype: torch.dtype = torch.float32
81
+ ) -> torch.Tensor:
82
+ r"""
83
+ Compute beta * K * K^T.
84
+
85
+ Args:
86
+ k (torch.Tensor):
87
+ The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`.
88
+ beta (torch.Tensor):
89
+ The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`.
90
+ cu_seqlens (torch.LongTensor):
91
+ The cumulative sequence lengths of the input tensor.
92
+ Default: None
93
+ head_first (bool):
94
+ If False, the input/output tensor is in the shape of `[B, T, H, K]`.
95
+ If True, the input/output tensor is in the shape of `[B, H, T, K]`.
96
+ Default: False
97
+ chunk_size (int):
98
+ The chunk size. Default: 64.
99
+ output_dtype (torch.dtype):
100
+ The dtype of the output tensor. Default: `torch.float32`
101
+
102
+ Returns:
103
+ beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`,
104
+ where `BT` is the chunk size.
105
+ """
106
+ if head_first:
107
+ B, H, T, K = k.shape
108
+ else:
109
+ B, T, H, K = k.shape
110
+ BT = chunk_size
111
+ indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
112
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices)
113
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype)
114
+ chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
115
+ k=k,
116
+ beta=beta,
117
+ A=A,
118
+ offsets=cu_seqlens,
119
+ indices=indices,
120
+ T=T,
121
+ H=H,
122
+ K=K,
123
+ BT=BT,
124
+ HEAD_FIRST=head_first
125
+ )
126
+ return A
fla/ops/common/fused_recurrent.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import chunk_global_cumsum
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps)
23
+ for num_warps in [1, 2, 4]
24
+ ],
25
+ key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ o,
36
+ h0,
37
+ ht,
38
+ offsets,
39
+ scale,
40
+ T,
41
+ B: tl.constexpr,
42
+ H: tl.constexpr,
43
+ K: tl.constexpr,
44
+ V: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ REVERSE: tl.constexpr,
48
+ USE_G: tl.constexpr,
49
+ USE_GK: tl.constexpr,
50
+ USE_GV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ STORE_FINAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ # indices
57
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
61
+ all = T
62
+ T = eos - bos
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ all = B * T
66
+
67
+ if HEAD_FIRST:
68
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
69
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
70
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
71
+ p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
72
+ if USE_G:
73
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
74
+ if USE_GK:
75
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
76
+ if USE_GV:
77
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
78
+ else:
79
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
80
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
82
+ p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
83
+ if USE_G:
84
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
85
+ if USE_GK:
86
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
87
+ if USE_GV:
88
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
89
+
90
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
91
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
92
+ mask_h = mask_k[None, :] & mask_v[:, None]
93
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
94
+
95
+ if USE_INITIAL_STATE:
96
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
97
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
98
+
99
+ for _ in range(0, T):
100
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
101
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
102
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
103
+ if USE_GK:
104
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
105
+ b_h = b_h * exp(b_gk[None, :])
106
+ if USE_GV:
107
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
108
+ b_h = b_h * exp(b_gv[:, None])
109
+ if USE_G:
110
+ b_g = tl.load(p_g).to(tl.float32)
111
+ b_h = b_h * exp(b_g)
112
+ b_h += b_k[None, :] * b_v[:, None]
113
+ b_o = b_h * b_q[None, :]
114
+ b_o = tl.sum(b_o, axis=1)
115
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
116
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
117
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
118
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
119
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
120
+ if USE_GK:
121
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
122
+ if USE_GV:
123
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
124
+ if USE_G:
125
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
129
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
130
+
131
+
132
+ @triton.heuristics({
133
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
134
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
135
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
136
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
137
+ })
138
+ @triton.autotune(
139
+ configs=[
140
+ triton.Config({}, num_warps=num_warps)
141
+ for num_warps in [1, 2, 4]
142
+ ],
143
+ key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'],
144
+ )
145
+ @triton.jit(do_not_specialize=['T'])
146
+ def fused_recurrent_bwd_kernel(
147
+ q,
148
+ k,
149
+ v,
150
+ g,
151
+ gk,
152
+ gv,
153
+ h0,
154
+ do,
155
+ dq,
156
+ dk,
157
+ dv,
158
+ dht,
159
+ dh0,
160
+ offsets,
161
+ scale,
162
+ T,
163
+ B: tl.constexpr,
164
+ H: tl.constexpr,
165
+ K: tl.constexpr,
166
+ V: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ BV: tl.constexpr,
169
+ REVERSE: tl.constexpr,
170
+ USE_G: tl.constexpr,
171
+ USE_GK: tl.constexpr,
172
+ USE_GV: tl.constexpr,
173
+ USE_INITIAL_STATE: tl.constexpr,
174
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
175
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
176
+ USE_OFFSETS: tl.constexpr,
177
+ HEAD_FIRST: tl.constexpr
178
+ ):
179
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
180
+ i_n, i_h = i_nh // H, i_nh % H
181
+ if USE_OFFSETS:
182
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
183
+ all = T
184
+ T = eos - bos
185
+ else:
186
+ bos, eos = i_n * T, i_n * T + T
187
+ all = B * T
188
+
189
+ if HEAD_FIRST:
190
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
191
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
192
+ p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
193
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
194
+ if USE_G:
195
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
196
+ if USE_GK:
197
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
198
+ if USE_GV:
199
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
200
+ else:
201
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
202
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
203
+ p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
204
+ p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
205
+ if USE_G:
206
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
207
+ if USE_GK:
208
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
209
+ if USE_GV:
210
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
211
+
212
+ mask_k = i_k * BK + tl.arange(0, BK) < K
213
+ mask_v = i_v * BV + tl.arange(0, BV) < V
214
+ mask_h = mask_k[:, None] & mask_v[None, :]
215
+
216
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
217
+ if USE_INITIAL_STATE:
218
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
219
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
220
+
221
+ for _ in range(0, T):
222
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
223
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
224
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
225
+ if USE_G:
226
+ b_g = tl.load(p_g).to(tl.float32)
227
+ b_h = b_h * exp(b_g)
228
+ if USE_GK:
229
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
230
+ b_h = b_h * exp(b_gk[:, None])
231
+ if USE_GV:
232
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
233
+ b_h = b_h * exp(b_gv[None, :])
234
+ b_h += b_k[:, None] * b_v[None, :]
235
+ b_dq = b_h * b_do[None, :]
236
+ b_dq = tl.sum(b_dq, axis=1) * scale
237
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
238
+
239
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
240
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
241
+ p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
242
+ p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
243
+ if USE_G:
244
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
245
+ if USE_GK:
246
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
247
+ if USE_GV:
248
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
249
+
250
+ # sync threads
251
+ tl.debug_barrier()
252
+
253
+ if HEAD_FIRST:
254
+ p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
255
+ p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
256
+ p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
257
+ p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
258
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
260
+ if USE_G:
261
+ p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0)
262
+ if USE_GK:
263
+ p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
264
+ if USE_GV:
265
+ p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
266
+ else:
267
+ p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+ p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
269
+ p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
270
+ p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
271
+ p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
272
+ p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
273
+ if USE_G:
274
+ p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h
275
+ if USE_GK:
276
+ p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
277
+ if USE_GV:
278
+ p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
279
+
280
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
281
+ if USE_FINAL_STATE_GRADIENT:
282
+ p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
283
+ b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32)
284
+
285
+ for _ in range(T):
286
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
287
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
288
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
289
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
290
+ b_dh += b_q[:, None] * b_do[None, :]
291
+ b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
292
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
293
+ if USE_G:
294
+ b_g = tl.load(p_g).to(tl.float32)
295
+ b_dh *= exp(b_g)
296
+ if USE_GK:
297
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
298
+ b_dh *= exp(b_gk)[:, None]
299
+ if USE_GV:
300
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
301
+ b_dh *= exp(b_gv)[None, :]
302
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
303
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
304
+
305
+ p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
306
+ p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
307
+ p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
308
+ p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
309
+ p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
310
+ p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
311
+ if USE_G:
312
+ p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H)
313
+ if USE_GK:
314
+ p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
315
+ if USE_GV:
316
+ p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
317
+
318
+ if STORE_INITIAL_STATE_GRADIENT:
319
+ p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
320
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
321
+
322
+
323
+ def fused_recurrent_fwd(
324
+ q: torch.Tensor,
325
+ k: torch.Tensor,
326
+ v: torch.Tensor,
327
+ g: Optional[torch.Tensor] = None,
328
+ gk: Optional[torch.Tensor] = None,
329
+ gv: Optional[torch.Tensor] = None,
330
+ scale: Optional[float] = None,
331
+ initial_state: Optional[torch.Tensor] = None,
332
+ output_final_state: bool = False,
333
+ reverse: bool = False,
334
+ offsets: Optional[torch.LongTensor] = None,
335
+ head_first: bool = True
336
+ ):
337
+ if head_first:
338
+ B, H, T, K, V = *k.shape, v.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *k.shape, v.shape[-1]
341
+ N = B if offsets is None else len(offsets) - 1
342
+ BK, BV = min(K, 64), min(V, 64)
343
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
344
+
345
+ h0 = initial_state
346
+ if output_final_state:
347
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
348
+ else:
349
+ ht = None
350
+ o = q.new_empty(NK, *v.shape, dtype=torch.float32)
351
+
352
+ grid = (NV, NK, N * H)
353
+ fused_recurrent_fwd_kernel[grid](
354
+ q,
355
+ k,
356
+ v,
357
+ g,
358
+ gk,
359
+ gv,
360
+ o,
361
+ h0,
362
+ ht,
363
+ offsets,
364
+ scale,
365
+ T=T,
366
+ B=B,
367
+ H=H,
368
+ K=K,
369
+ V=V,
370
+ BK=BK,
371
+ BV=BV,
372
+ USE_G=g is not None,
373
+ USE_GK=gk is not None,
374
+ USE_GV=gv is not None,
375
+ REVERSE=reverse,
376
+ HEAD_FIRST=head_first
377
+ )
378
+ o = o.sum(0)
379
+ return o, ht
380
+
381
+
382
+ def fused_recurrent_bwd(
383
+ q: torch.Tensor,
384
+ k: torch.Tensor,
385
+ v: torch.Tensor,
386
+ g: Optional[torch.Tensor] = None,
387
+ gk: Optional[torch.Tensor] = None,
388
+ gv: Optional[torch.Tensor] = None,
389
+ o: Optional[torch.Tensor] = None,
390
+ do: Optional[torch.Tensor] = None,
391
+ dht: Optional[torch.Tensor] = None,
392
+ scale: Optional[float] = None,
393
+ initial_state: Optional[torch.Tensor] = None,
394
+ reverse: bool = False,
395
+ offsets: Optional[torch.LongTensor] = None,
396
+ head_first: bool = True
397
+ ):
398
+ if head_first:
399
+ B, H, T, K, V = *k.shape, v.shape[-1]
400
+ else:
401
+ B, T, H, K, V = *k.shape, v.shape[-1]
402
+ N = B if offsets is None else len(offsets) - 1
403
+
404
+ BK, BV = min(K, 64), min(V, 64)
405
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
406
+
407
+ dq = q.new_empty(NV, *q.shape, dtype=torch.float32)
408
+ dk = q.new_empty(NV, *k.shape, dtype=torch.float32)
409
+ dv = q.new_empty(NK, *v.shape, dtype=torch.float32)
410
+ h0 = initial_state
411
+ dh0 = torch.empty_like(initial_state) if initial_state is not None else None
412
+
413
+ grid = (NV, NK, N * H)
414
+ fused_recurrent_bwd_kernel[grid](
415
+ q,
416
+ k,
417
+ v,
418
+ g,
419
+ gk,
420
+ gv,
421
+ h0,
422
+ do,
423
+ dq,
424
+ dk,
425
+ dv,
426
+ dht,
427
+ dh0,
428
+ offsets,
429
+ scale,
430
+ B=B,
431
+ T=T,
432
+ H=H,
433
+ K=K,
434
+ V=V,
435
+ BK=BK,
436
+ BV=BV,
437
+ USE_G=g is not None,
438
+ USE_GK=gk is not None,
439
+ USE_GV=gv is not None,
440
+ REVERSE=reverse,
441
+ HEAD_FIRST=head_first
442
+ )
443
+ dq = dq.sum(0)
444
+ dk = dk.sum(0)
445
+ dv = dv.sum(0)
446
+ dg, dgk, dgv = None, None, None
447
+ if g is not None:
448
+ dg = chunk_global_cumsum(
449
+ (dq * q.float() - dk * k.float()).sum(-1),
450
+ reverse=not reverse,
451
+ offsets=offsets,
452
+ head_first=head_first
453
+ )
454
+ if gk is not None:
455
+ dgk = chunk_global_cumsum(
456
+ dq * q.float() - dk * k.float(),
457
+ reverse=not reverse,
458
+ offsets=offsets,
459
+ head_first=head_first
460
+ )
461
+ if gv is not None:
462
+ dgv = chunk_global_cumsum(
463
+ do.float() * o.float() - dv * v.float(),
464
+ reverse=not reverse,
465
+ offsets=offsets,
466
+ head_first=head_first
467
+ )
468
+
469
+ return dq, dk, dv, dg, dgk, dgv, dh0
470
+
471
+
472
+ class FusedRecurrentFunction(torch.autograd.Function):
473
+
474
+ @staticmethod
475
+ @input_guard
476
+ @autocast_custom_fwd
477
+ def forward(
478
+ ctx,
479
+ q: torch.Tensor,
480
+ k: torch.Tensor,
481
+ v: torch.Tensor,
482
+ g: Optional[torch.Tensor] = None,
483
+ gk: Optional[torch.Tensor] = None,
484
+ gv: Optional[torch.Tensor] = None,
485
+ scale: Optional[float] = None,
486
+ initial_state: Optional[torch.Tensor] = None,
487
+ output_final_state: bool = False,
488
+ reverse: bool = False,
489
+ offsets: Optional[torch.LongTensor] = None,
490
+ head_first: bool = True
491
+ ):
492
+ o, ht = fused_recurrent_fwd(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ gk=gk,
498
+ gv=gv,
499
+ scale=scale,
500
+ initial_state=initial_state,
501
+ output_final_state=output_final_state,
502
+ reverse=reverse,
503
+ offsets=offsets,
504
+ head_first=head_first
505
+ )
506
+ ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o)
507
+ ctx.scale = scale
508
+ ctx.reverse = reverse
509
+ ctx.offsets = offsets
510
+ ctx.head_first = head_first
511
+ return o.to(q.dtype), ht
512
+
513
+ @staticmethod
514
+ @input_guard
515
+ @autocast_custom_bwd
516
+ def backward(ctx, do, dht):
517
+ q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors
518
+ # not supported yet.
519
+ if dht is not None:
520
+ if not dht.eq(0).all():
521
+ if g is not None:
522
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
523
+ if gk is not None:
524
+ assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
525
+ if gv is not None:
526
+ assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
527
+ dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd(
528
+ q=q,
529
+ k=k,
530
+ v=v,
531
+ g=g,
532
+ gk=gk,
533
+ gv=gv,
534
+ o=o,
535
+ do=do,
536
+ dht=dht,
537
+ scale=ctx.scale,
538
+ initial_state=initial_state,
539
+ reverse=ctx.reverse,
540
+ offsets=ctx.offsets,
541
+ head_first=ctx.head_first
542
+ )
543
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None
544
+
545
+
546
+ def fused_recurrent(
547
+ q: torch.Tensor,
548
+ k: torch.Tensor,
549
+ v: torch.Tensor,
550
+ g: Optional[torch.Tensor] = None,
551
+ gk: Optional[torch.Tensor] = None,
552
+ gv: Optional[torch.Tensor] = None,
553
+ scale: Optional[float] = None,
554
+ initial_state: Optional[torch.Tensor] = None,
555
+ output_final_state: bool = False,
556
+ reverse: bool = False,
557
+ cu_seqlens: Optional[torch.LongTensor] = None,
558
+ head_first: bool = True
559
+ ):
560
+ if scale is None:
561
+ scale = k.shape[-1] ** -0.5
562
+ return FusedRecurrentFunction.apply(
563
+ q,
564
+ k,
565
+ v,
566
+ g,
567
+ gk,
568
+ gv,
569
+ scale,
570
+ initial_state,
571
+ output_final_state,
572
+ reverse,
573
+ cu_seqlens,
574
+ head_first
575
+ )
fla/ops/common/utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from fla.utils import tensor_cache
9
+
10
+
11
+ @triton.autotune(
12
+ configs=[
13
+ triton.Config({}, num_warps=num_warps)
14
+ for num_warps in [4, 8, 16, 32]
15
+ ],
16
+ key=['B'],
17
+ )
18
+ @triton.jit
19
+ def prepare_position_ids_kernel(
20
+ y,
21
+ offsets,
22
+ B: tl.constexpr
23
+ ):
24
+ i_n = tl.program_id(0)
25
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
26
+ T = eos - bos
27
+
28
+ o = tl.arange(0, B)
29
+ for i in range(0, tl.cdiv(T, B) * B, B):
30
+ o_i = o + i
31
+ tl.store(y + bos + o_i, o_i, o_i < T)
32
+
33
+
34
+ @tensor_cache
35
+ def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor:
36
+ return offsets[1:] - offsets[:-1]
37
+
38
+
39
+ @tensor_cache
40
+ def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor:
41
+ return torch.cat([torch.arange(n, dtype=offsets.dtype, device=offsets.device) for n in prepare_lens(offsets).unbind()])
42
+
43
+
44
+ @tensor_cache
45
+ def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
46
+ return position_ids.eq(0).cumsum(0) - 1
47
+
48
+
49
+ @tensor_cache
50
+ def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor:
51
+ position_ids = prepare_position_ids(offsets)
52
+ return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets)
53
+
54
+
55
+ @tensor_cache
56
+ def prepare_chunk_indices(
57
+ offsets: torch.LongTensor,
58
+ chunk_size: int
59
+ ) -> torch.LongTensor:
60
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()])
61
+ return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets)
62
+
63
+
64
+ @tensor_cache
65
+ def prepare_chunk_offsets(
66
+ offsets: torch.LongTensor,
67
+ chunk_size: int
68
+ ) -> torch.LongTensor:
69
+ return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1)
fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (34 kB). View file
 
fla/ops/delta_rule/fused_chunk.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ def fused_chunk_delta_rule(
4
+ **kwargs
5
+ ):
6
+ raise NotImplementedError("fused_chunk_delta_rule is deprecated. Please use chunk_delta_rule instead.")
fla/ops/forgetting_attn/parallel.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum
13
+ from fla.ops.utils.op import div, exp, log
14
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
24
+ for num_stages in [2, 3, 4, 5]
25
+ ],
26
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
27
+ )
28
+ @triton.jit
29
+ def parallel_forgetting_attn_fwd_kernel(
30
+ q,
31
+ k,
32
+ v,
33
+ g,
34
+ o,
35
+ lse,
36
+ scale,
37
+ offsets,
38
+ indices,
39
+ T,
40
+ B: tl.constexpr,
41
+ H: tl.constexpr,
42
+ HQ: tl.constexpr,
43
+ G: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BS: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr
51
+ ):
52
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
54
+ i_h = i_hq // G
55
+
56
+ if USE_OFFSETS:
57
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
58
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
59
+ T = eos - bos
60
+ else:
61
+ i_n = i_b
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
65
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
66
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
67
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
68
+
69
+ # the Q block is kept in the shared memory throughout the whole kernel
70
+ # [BT, BK]
71
+ b_q = tl.load(p_q, boundary_check=(0, 1))
72
+ b_q = (b_q * scale).to(b_q.dtype)
73
+ # [BT,]
74
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
75
+ # [BT, BV]
76
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
77
+
78
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
79
+ b_acc = tl.zeros([BT], dtype=tl.float32)
80
+
81
+ # [BT]
82
+ o_q = i_t * BT + tl.arange(0, BT)
83
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
84
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
86
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
87
+
88
+ # [BS]
89
+ o_k = i_s + tl.arange(0, BS)
90
+ # [BK, BS]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BS, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ # [BS,]
95
+ b_gk = tl.load(p_gk, boundary_check=(0,))
96
+ # [BT, BS]
97
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :]
98
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
99
+
100
+ # [BT]
101
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
102
+ b_r = exp(b_mp - b_m)
103
+ # [BT, BS]
104
+ b_p = exp(b_s - b_m[:, None])
105
+ # [BT]
106
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
107
+ # [BT, BV]
108
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
109
+
110
+ b_mp = b_m
111
+
112
+ for i_s in range(i_t * BT - BS, -BS, -BS):
113
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
114
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
115
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
116
+
117
+ # [BK, BS]
118
+ b_k = tl.load(p_k, boundary_check=(0, 1))
119
+ # [BS, BV]
120
+ b_v = tl.load(p_v, boundary_check=(0, 1))
121
+ # [BS,]
122
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
123
+
124
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
125
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
126
+ # [BT, BS]
127
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :]
128
+
129
+ b_gq += b_gn - b_gp
130
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
131
+ b_r = exp(b_mp - b_m)
132
+ # [BT, BS]
133
+ b_p = exp(b_s - b_m[:, None])
134
+ # [BT]
135
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
136
+ # [BT, BV]
137
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
138
+
139
+ b_mp = b_m
140
+
141
+ b_o = div(b_o, b_acc[:, None])
142
+ b_m += log(b_acc)
143
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
144
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
145
+
146
+
147
+ @triton.jit
148
+ def parallel_forgetting_attn_bwd_kernel_preprocess(
149
+ o,
150
+ do,
151
+ delta,
152
+ B: tl.constexpr,
153
+ V: tl.constexpr
154
+ ):
155
+ i_n = tl.program_id(0)
156
+ o_d = tl.arange(0, B)
157
+ m_d = o_d < V
158
+
159
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
160
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
161
+ b_delta = tl.sum(b_o * b_do)
162
+
163
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
164
+
165
+
166
+ @triton.heuristics({
167
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
168
+ })
169
+ @triton.autotune(
170
+ configs=[
171
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
172
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
173
+ for num_stages in [2, 3, 4]
174
+ ],
175
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
176
+ )
177
+ @triton.jit(do_not_specialize=['T'])
178
+ def parallel_forgetting_attn_bwd_kernel_dq(
179
+ q,
180
+ k,
181
+ v,
182
+ g,
183
+ lse,
184
+ delta,
185
+ do,
186
+ dq,
187
+ dg,
188
+ scale,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ B: tl.constexpr,
193
+ H: tl.constexpr,
194
+ HQ: tl.constexpr,
195
+ G: tl.constexpr,
196
+ K: tl.constexpr,
197
+ V: tl.constexpr,
198
+ BT: tl.constexpr,
199
+ BS: tl.constexpr,
200
+ BK: tl.constexpr,
201
+ BV: tl.constexpr,
202
+ USE_OFFSETS: tl.constexpr
203
+ ):
204
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
205
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
206
+ i_h = i_hq // G
207
+
208
+ if USE_OFFSETS:
209
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
210
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
211
+ T = eos - bos
212
+ else:
213
+ i_n = i_b
214
+ bos, eos = i_n * T, i_n * T + T
215
+
216
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
217
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
218
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
219
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
220
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
221
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
222
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
223
+
224
+ # [BT, BK]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale).to(b_q.dtype)
227
+ # [BT, BV]
228
+ b_do = tl.load(p_do, boundary_check=(0, 1))
229
+ # [BT]
230
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
231
+ b_lse = tl.load(p_lse, boundary_check=(0,))
232
+ b_delta = tl.load(p_delta, boundary_check=(0,))
233
+
234
+ # [BT]
235
+ o_q = i_t * BT + tl.arange(0, BT)
236
+ # [BT, BK]
237
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
238
+ # [BT]
239
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
240
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
241
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
242
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
243
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
244
+
245
+ # [BS]
246
+ o_k = i_s + tl.arange(0, BS)
247
+ # [BK, BS]
248
+ b_k = tl.load(p_k, boundary_check=(0, 1))
249
+ # [BV, BS]
250
+ b_v = tl.load(p_v, boundary_check=(0, 1))
251
+ # [BS,]
252
+ b_gk = tl.load(p_gk, boundary_check=(0,))
253
+ # [BT, BS]
254
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :]
255
+ b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')))
256
+
257
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
258
+ b_dp = tl.dot(b_do, b_v)
259
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
260
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
261
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
262
+ # [BT]
263
+ b_dg += tl.sum(b_ds, 1)
264
+
265
+ for i_s in range(i_t * BT - BS, -BS, -BS):
266
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
267
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
268
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
269
+
270
+ # [BK, BS]
271
+ b_k = tl.load(p_k, boundary_check=(0, 1))
272
+ # [BV, BS]
273
+ b_v = tl.load(p_v, boundary_check=(0, 1))
274
+ # [BS,]
275
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
276
+
277
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
278
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
279
+ # [BT, BS]
280
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :]
281
+ b_p = exp(b_s)
282
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
283
+ b_dp = tl.dot(b_do, b_v)
284
+ b_ds = b_p * (b_dp - b_delta[:, None])
285
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
286
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
287
+ # [BT]
288
+ b_dg += tl.sum(b_ds, 1)
289
+
290
+ b_gq += b_gn - b_gp
291
+
292
+ b_dq *= scale
293
+
294
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
295
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
296
+
297
+
298
+ @triton.heuristics({
299
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
300
+ })
301
+ @triton.autotune(
302
+ configs=[
303
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
304
+ for num_warps in [1, 2, 4, 8]
305
+ for num_stages in [2, 3, 4]
306
+ ],
307
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
308
+ )
309
+ @triton.jit(do_not_specialize=['T'])
310
+ def parallel_forgetting_attn_bwd_kernel_dkv(
311
+ q,
312
+ k,
313
+ v,
314
+ g,
315
+ lse,
316
+ delta,
317
+ do,
318
+ dk,
319
+ dv,
320
+ dg,
321
+ offsets,
322
+ indices,
323
+ scale,
324
+ T,
325
+ B: tl.constexpr,
326
+ H: tl.constexpr,
327
+ HQ: tl.constexpr,
328
+ G: tl.constexpr,
329
+ K: tl.constexpr,
330
+ V: tl.constexpr,
331
+ BT: tl.constexpr,
332
+ BS: tl.constexpr,
333
+ BK: tl.constexpr,
334
+ BV: tl.constexpr,
335
+ USE_OFFSETS: tl.constexpr
336
+ ):
337
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
338
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
339
+ i_h = i_hq // G
340
+
341
+ if USE_OFFSETS:
342
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
343
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
344
+ T = eos - bos
345
+ else:
346
+ i_n = i_b
347
+ bos, eos = i_n * T, i_n * T + T
348
+
349
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
350
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
351
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
352
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
353
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
354
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
355
+
356
+ # [BT, BK]
357
+ b_k = tl.load(p_k, boundary_check=(0, 1))
358
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
359
+ # [BT, BV]
360
+ b_v = tl.load(p_v, boundary_check=(0, 1))
361
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
362
+ # [BT]
363
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
364
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
365
+
366
+ o_k = i_t * BT + tl.arange(0, BT)
367
+ m_k = o_k < T
368
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
369
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
370
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
371
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
372
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
373
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
374
+
375
+ # [BS]
376
+ o_q = i_s + tl.arange(0, BS)
377
+ # [BS, BK]
378
+ b_q = tl.load(p_q, boundary_check=(0, 1))
379
+ b_q = (b_q * scale).to(b_q.dtype)
380
+ # [BS, BV]
381
+ b_do = tl.load(p_do, boundary_check=(0, 1))
382
+ # [BS]
383
+ b_lse = tl.load(p_lse, boundary_check=(0,))
384
+ b_delta = tl.load(p_delta, boundary_check=(0,))
385
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
386
+
387
+ m_q = o_q < T
388
+ m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :]
389
+ # [BT, BS]
390
+ b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :]
391
+ b_p = tl.where(m_s, exp(b_s), 0)
392
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
393
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
394
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
395
+ b_dp = tl.dot(b_v, tl.trans(b_do))
396
+ # [BT, BS]
397
+ b_ds = b_p * (b_dp - b_delta[None, :])
398
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
399
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
400
+ # [BT]
401
+ b_dg -= tl.sum(b_ds, 1)
402
+
403
+ b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32)
404
+ for i_s in range((i_t + 1) * BT, T, BS):
405
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
406
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
407
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
408
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
409
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
410
+
411
+ # [BS]
412
+ o_q = i_s + tl.arange(0, BS)
413
+ # [BS, BK]
414
+ b_q = tl.load(p_q, boundary_check=(0, 1))
415
+ b_q = (b_q * scale).to(b_q.dtype)
416
+ # [BS, BV]
417
+ b_do = tl.load(p_do, boundary_check=(0, 1))
418
+ # [BS]
419
+ b_lse = tl.load(p_lse, boundary_check=(0,))
420
+ b_delta = tl.load(p_delta, boundary_check=(0,))
421
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
422
+
423
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
424
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
425
+ # [BT, BS]
426
+ b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :]
427
+ b_p = exp(b_s)
428
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
429
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
430
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
431
+ b_dp = tl.dot(b_v, tl.trans(b_do))
432
+ # [BT, BS]
433
+ b_ds = b_p * (b_dp - b_delta[None, :])
434
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
435
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
436
+ # [BT]
437
+ b_dg -= tl.sum(b_ds, 1)
438
+
439
+ b_gk -= b_gn - b_gp
440
+
441
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
442
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
443
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
444
+
445
+
446
+ def parallel_forgetting_attn_fwd(
447
+ q: torch.Tensor,
448
+ k: torch.Tensor,
449
+ v: torch.Tensor,
450
+ g: torch.Tensor,
451
+ scale: float,
452
+ chunk_size: int = 128,
453
+ offsets: Optional[torch.LongTensor] = None,
454
+ indices: Optional[torch.LongTensor] = None,
455
+ ):
456
+ B, T, H, K, V = *k.shape, v.shape[-1]
457
+ HQ = q.shape[2]
458
+ G = HQ // H
459
+ BT = chunk_size
460
+ BK = max(16, triton.next_power_of_2(K))
461
+ assert V <= 256, "V must be less than or equal to 256"
462
+ if check_shared_mem('hopper'):
463
+ BS = min(64, max(16, triton.next_power_of_2(T)))
464
+ else:
465
+ BS = min(32, max(16, triton.next_power_of_2(T)))
466
+ BV = min(256, max(16, triton.next_power_of_2(V)))
467
+ NV = triton.cdiv(V, BV)
468
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
469
+
470
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
471
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
472
+
473
+ grid = (NV, NT, B * HQ)
474
+ parallel_forgetting_attn_fwd_kernel[grid](
475
+ q=q,
476
+ k=k,
477
+ v=v,
478
+ g=g,
479
+ o=o,
480
+ lse=lse,
481
+ scale=scale,
482
+ offsets=offsets,
483
+ indices=indices,
484
+ B=B,
485
+ T=T,
486
+ H=H,
487
+ HQ=HQ,
488
+ G=G,
489
+ K=K,
490
+ V=V,
491
+ BT=BT,
492
+ BS=BS,
493
+ BK=BK,
494
+ BV=BV,
495
+ )
496
+ return o, lse
497
+
498
+
499
+ def parallel_forgetting_attn_bwd_preprocess(
500
+ o: torch.Tensor,
501
+ do: torch.Tensor
502
+ ):
503
+ V = o.shape[-1]
504
+ delta = torch.empty_like(o[..., 0], dtype=torch.float)
505
+ parallel_forgetting_attn_bwd_kernel_preprocess[(delta.numel(),)](
506
+ o=o,
507
+ do=do,
508
+ delta=delta,
509
+ B=triton.next_power_of_2(V),
510
+ V=V,
511
+ )
512
+ return delta
513
+
514
+
515
+ def parallel_forgetting_attn_bwd(
516
+ q: torch.Tensor,
517
+ k: torch.Tensor,
518
+ v: torch.Tensor,
519
+ g: torch.Tensor,
520
+ o: torch.Tensor,
521
+ lse: torch.Tensor,
522
+ do: torch.Tensor,
523
+ scale: float = None,
524
+ chunk_size: int = 128,
525
+ offsets: Optional[torch.LongTensor] = None,
526
+ indices: Optional[torch.LongTensor] = None,
527
+ ):
528
+ B, T, H, K, V = *k.shape, v.shape[-1]
529
+ HQ = q.shape[2]
530
+ G = HQ // H
531
+ BT = chunk_size
532
+ BS = min(32, max(16, triton.next_power_of_2(T)))
533
+ BK = max(16, triton.next_power_of_2(K))
534
+ BV = max(16, triton.next_power_of_2(V))
535
+ NV = triton.cdiv(V, BV)
536
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
537
+
538
+ delta = parallel_forgetting_attn_bwd_preprocess(o, do)
539
+ dq = q.new_empty(B, T, HQ, K, dtype=q.dtype)
540
+ dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float)
541
+ dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float)
542
+ dg = q.new_empty(g.shape, dtype=torch.float)
543
+ # NOTE: the original `dg` can be destroyed during autotuning
544
+ # this is [a known triton issue](https://github.com/triton-lang/triton/issues/5082), which will be fixed in 3.3 (?)
545
+ # so we need to make a copy of `dg`
546
+ dg2 = q.new_empty(g.shape, dtype=torch.float)
547
+ grid = (NV, NT, B * HQ)
548
+ parallel_forgetting_attn_bwd_kernel_dq[grid](
549
+ q=q,
550
+ k=k,
551
+ v=v,
552
+ g=g,
553
+ lse=lse,
554
+ delta=delta,
555
+ do=do,
556
+ dq=dq,
557
+ dg=dg,
558
+ offsets=offsets,
559
+ indices=indices,
560
+ scale=scale,
561
+ T=T,
562
+ B=B,
563
+ H=H,
564
+ HQ=HQ,
565
+ G=G,
566
+ K=K,
567
+ V=V,
568
+ BT=BT,
569
+ BS=BS,
570
+ BK=BK,
571
+ BV=BV
572
+ )
573
+ parallel_forgetting_attn_bwd_kernel_dkv[grid](
574
+ q=q,
575
+ k=k,
576
+ v=v,
577
+ g=g,
578
+ lse=lse,
579
+ delta=delta,
580
+ do=do,
581
+ dk=dk,
582
+ dv=dv,
583
+ dg=dg2,
584
+ offsets=offsets,
585
+ indices=indices,
586
+ scale=scale,
587
+ T=T,
588
+ B=B,
589
+ H=H,
590
+ HQ=HQ,
591
+ G=G,
592
+ K=K,
593
+ V=V,
594
+ BT=BT,
595
+ BS=BS,
596
+ BK=BK,
597
+ BV=BV
598
+ )
599
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
600
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
601
+ dg = dg.add_(dg2)
602
+ return dq, dk, dv, dg
603
+
604
+
605
+ @torch.compile
606
+ class ParallelForgettingAttentionFunction(torch.autograd.Function):
607
+
608
+ @staticmethod
609
+ @input_guard
610
+ @autocast_custom_fwd
611
+ def forward(ctx, q, k, v, g, scale, offsets):
612
+ ctx.dtype = q.dtype
613
+ if check_shared_mem('hopper'):
614
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
615
+ else:
616
+ chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1])))
617
+ # 2-d indices denoting the offsets of chunks in each sequence
618
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
619
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
620
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
621
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
622
+
623
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False)
624
+ o, lse = parallel_forgetting_attn_fwd(
625
+ q=q,
626
+ k=k,
627
+ v=v,
628
+ g=g,
629
+ scale=scale,
630
+ chunk_size=chunk_size,
631
+ offsets=offsets,
632
+ indices=indices
633
+ )
634
+ ctx.save_for_backward(q, k, v, g, o, lse)
635
+ ctx.chunk_size = chunk_size
636
+ ctx.offsets = offsets
637
+ ctx.indices = indices
638
+ ctx.scale = scale
639
+ return o.to(q.dtype)
640
+
641
+ @staticmethod
642
+ @input_guard
643
+ @autocast_custom_bwd
644
+ def backward(ctx, do):
645
+ q, k, v, g, o, lse = ctx.saved_tensors
646
+ dq, dk, dv, dg = parallel_forgetting_attn_bwd(
647
+ q=q,
648
+ k=k,
649
+ v=v,
650
+ g=g,
651
+ o=o,
652
+ lse=lse,
653
+ do=do,
654
+ scale=ctx.scale,
655
+ chunk_size=ctx.chunk_size,
656
+ offsets=ctx.offsets,
657
+ indices=ctx.indices
658
+ )
659
+ dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets)
660
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None, None, None, None, None, None, None
661
+
662
+
663
+ def parallel_forgetting_attn(
664
+ q: torch.Tensor,
665
+ k: torch.Tensor,
666
+ v: torch.Tensor,
667
+ g: torch.Tensor,
668
+ scale: Optional[float] = None,
669
+ cu_seqlens: Optional[torch.LongTensor] = None,
670
+ head_first: bool = False
671
+ ) -> torch.Tensor:
672
+ r"""
673
+ Args:
674
+ q (torch.Tensor):
675
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
676
+ k (torch.Tensor):
677
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
678
+ GQA will be applied if HQ is divisible by H.
679
+ v (torch.Tensor):
680
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
681
+ g (torch.Tensor):
682
+ Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
683
+ scale (Optional[int]):
684
+ Scale factor for attention scores.
685
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
686
+ cu_seqlens (torch.LongTensor):
687
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
688
+ consistent with the FlashAttention API.
689
+ head_first (Optional[bool]):
690
+ Whether the inputs are in the head-first format. Default: `False`.
691
+
692
+ Returns:
693
+ o (torch.Tensor):
694
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
695
+ """
696
+ if scale is None:
697
+ scale = k.shape[-1] ** -0.5
698
+ if cu_seqlens is not None:
699
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
700
+ if g is not None:
701
+ g = g.float()
702
+ if head_first:
703
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
704
+ g = rearrange(g, 'b h t -> b t h')
705
+ o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens)
706
+ if head_first:
707
+ o = rearrange(o, 'b t h d -> b h t d')
708
+ return o
fla/ops/gated_delta_rule/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_gated_delta_rule
2
+ from .fused_recurrent import fused_recurrent_gated_delta_rule
3
+
4
+ __all__ = [
5
+ "chunk_gated_delta_rule",
6
+ "fused_recurrent_gated_delta_rule"
7
+ ]
fla/ops/gated_delta_rule/chunk.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
14
+ from fla.ops.utils import chunk_local_cumsum
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ def chunk_gated_delta_rule_fwd(
19
+ q: torch.Tensor,
20
+ k: torch.Tensor,
21
+ v: torch.Tensor,
22
+ g: torch.Tensor,
23
+ beta: torch.Tensor,
24
+ scale: float,
25
+ initial_state: torch.Tensor,
26
+ output_final_state: bool,
27
+ offsets: Optional[torch.LongTensor] = None,
28
+ indices: Optional[torch.LongTensor] = None,
29
+ head_first: bool = True,
30
+ chunk_size: int = 64
31
+ ):
32
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
33
+ # obtain WY representation. u is actually the new v.
34
+ w, u, Aw, Au = fwd_prepare_wy_repr(
35
+ k=k,
36
+ v=v,
37
+ beta=beta,
38
+ g=g,
39
+ offsets=offsets,
40
+ indices=indices,
41
+ head_first=head_first,
42
+ chunk_size=chunk_size
43
+ )
44
+
45
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
46
+ k=k,
47
+ w=w,
48
+ u=u,
49
+ g=g,
50
+ initial_state=initial_state,
51
+ output_final_state=output_final_state,
52
+ offsets=offsets,
53
+ indices=indices,
54
+ head_first=head_first,
55
+ chunk_size=chunk_size
56
+ )
57
+
58
+ # obtain output
59
+ o = chunk_fwd_o(
60
+ q=q,
61
+ k=k,
62
+ v=v_new,
63
+ h=h,
64
+ g=g,
65
+ scale=scale,
66
+ offsets=offsets,
67
+ indices=indices,
68
+ head_first=head_first,
69
+ chunk_size=chunk_size
70
+ )
71
+ return g, o, Aw, Au, final_state
72
+
73
+
74
+ def chunk_gated_delta_rule_bwd(
75
+ q: torch.Tensor,
76
+ k: torch.Tensor,
77
+ v: torch.Tensor,
78
+ g: torch.Tensor,
79
+ beta: torch.Tensor,
80
+ Aw: torch.Tensor,
81
+ Au: torch.Tensor,
82
+ scale: float,
83
+ initial_state: torch.Tensor,
84
+ do: torch.Tensor,
85
+ dht: torch.Tensor,
86
+ offsets: Optional[torch.LongTensor] = None,
87
+ indices: Optional[torch.LongTensor] = None,
88
+ head_first: bool = True,
89
+ chunk_size: int = 64
90
+ ):
91
+ T = q.shape[2] if head_first else q.shape[1]
92
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
93
+ w, u = fwd_recompute_w_u(
94
+ k=k,
95
+ v=v,
96
+ beta=beta,
97
+ Aw=Aw,
98
+ Au=Au,
99
+ offsets=offsets,
100
+ indices=indices,
101
+ head_first=head_first,
102
+ chunk_size=BT
103
+ )
104
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
105
+ k=k,
106
+ w=w,
107
+ u=u,
108
+ g=g,
109
+ initial_state=initial_state,
110
+ output_final_state=False,
111
+ offsets=offsets,
112
+ indices=indices,
113
+ head_first=head_first,
114
+ chunk_size=BT
115
+ )
116
+ dv = chunk_bwd_dv_local(
117
+ q=q,
118
+ k=k,
119
+ g=g,
120
+ do=do,
121
+ dh=None,
122
+ scale=scale,
123
+ offsets=offsets,
124
+ indices=indices,
125
+ head_first=head_first,
126
+ chunk_size=BT
127
+ )
128
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
129
+ q=q,
130
+ k=k,
131
+ w=w,
132
+ g=g,
133
+ h0=initial_state,
134
+ dht=dht,
135
+ do=do,
136
+ dv=dv,
137
+ scale=scale,
138
+ offsets=offsets,
139
+ indices=indices,
140
+ head_first=head_first,
141
+ chunk_size=BT
142
+ )
143
+ dq, dk, dw, dg = chunk_bwd_dqkwg(
144
+ q=q,
145
+ k=k,
146
+ v=v_new,
147
+ w=w,
148
+ g=g,
149
+ h=h,
150
+ dv=dv,
151
+ do=do,
152
+ dh=dh,
153
+ scale=scale,
154
+ offsets=offsets,
155
+ indices=indices,
156
+ head_first=head_first,
157
+ chunk_size=BT
158
+ )
159
+ dk2, dv, db, dg2 = bwd_prepare_wy_repr(
160
+ k=k,
161
+ v=v,
162
+ beta=beta,
163
+ g=g,
164
+ Aw=Aw,
165
+ Au=Au,
166
+ dw=dw,
167
+ du=dv,
168
+ offsets=offsets,
169
+ indices=indices,
170
+ head_first=head_first,
171
+ chunk_size=BT
172
+ )
173
+ dk.add_(dk2)
174
+ dg.add_(dg2)
175
+ assert dg.dtype == torch.float32, "dg should be fp32"
176
+ dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, indices=indices, head_first=head_first)
177
+ return dq, dk, dv, db, dg, dh0
178
+
179
+
180
+ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
181
+
182
+ @staticmethod
183
+ @input_guard
184
+ @autocast_custom_fwd
185
+ def forward(
186
+ ctx,
187
+ q: torch.Tensor,
188
+ k: torch.Tensor,
189
+ v: torch.Tensor,
190
+ g: torch.Tensor,
191
+ beta: torch.Tensor,
192
+ scale: float,
193
+ initial_state: torch.Tensor,
194
+ output_final_state: bool,
195
+ offsets: Optional[torch.LongTensor] = None,
196
+ head_first: bool = True,
197
+ use_qk_l2norm_in_kernel: bool = False
198
+ ):
199
+ chunk_size = 64
200
+ q_orig = q
201
+ k_orig = k
202
+
203
+ if use_qk_l2norm_in_kernel:
204
+ q = l2norm_fwd(q)
205
+ k = l2norm_fwd(k)
206
+
207
+ # 2-d indices denoting the offsets of chunks in each sequence
208
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
209
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
210
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
211
+ indices = None
212
+ if offsets is not None:
213
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
214
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
215
+
216
+ g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd(
217
+ q=q,
218
+ k=k,
219
+ v=v,
220
+ g=g,
221
+ beta=beta,
222
+ scale=scale,
223
+ initial_state=initial_state,
224
+ output_final_state=output_final_state,
225
+ offsets=offsets,
226
+ indices=indices,
227
+ head_first=head_first,
228
+ chunk_size=chunk_size,
229
+ )
230
+ ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, offsets, indices)
231
+ ctx.chunk_size = chunk_size
232
+ ctx.scale = scale
233
+ ctx.head_first = head_first
234
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
235
+ return o.to(q.dtype), final_state
236
+
237
+ @staticmethod
238
+ @input_guard
239
+ @autocast_custom_bwd
240
+ def backward(
241
+ ctx,
242
+ do: torch.Tensor,
243
+ dht: torch.Tensor
244
+ ):
245
+ q, k, v, g, beta, Aw, Au, initial_state, offsets, indices = ctx.saved_tensors
246
+ if ctx.use_qk_l2norm_in_kernel:
247
+ q, q_orig = l2norm_fwd(q), q
248
+ k, k_orig = l2norm_fwd(k), k
249
+ dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
250
+ q=q,
251
+ k=k,
252
+ v=v,
253
+ g=g,
254
+ beta=beta,
255
+ Aw=Aw,
256
+ Au=Au,
257
+ scale=ctx.scale,
258
+ initial_state=initial_state,
259
+ do=do,
260
+ dht=dht,
261
+ offsets=offsets,
262
+ indices=indices,
263
+ head_first=ctx.head_first,
264
+ chunk_size=ctx.chunk_size
265
+ )
266
+ if ctx.use_qk_l2norm_in_kernel:
267
+ dq = l2norm_bwd(q_orig, dq)
268
+ dk = l2norm_bwd(k_orig, dk)
269
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None
270
+
271
+
272
+ @torch.compiler.disable
273
+ def chunk_gated_delta_rule(
274
+ q: torch.Tensor,
275
+ k: torch.Tensor,
276
+ v: torch.Tensor,
277
+ g: torch.Tensor,
278
+ beta: torch.Tensor,
279
+ scale: float = None,
280
+ initial_state: torch.Tensor = None,
281
+ output_final_state: bool = False,
282
+ cu_seqlens: Optional[torch.LongTensor] = None,
283
+ head_first: bool = False,
284
+ use_qk_l2norm_in_kernel: bool = False
285
+ ):
286
+ r"""
287
+ Args:
288
+ q (torch.Tensor):
289
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
290
+ k (torch.Tensor):
291
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
292
+ v (torch.Tensor):
293
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
294
+ g (torch.Tensor):
295
+ (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
296
+ beta (torch.Tensor):
297
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
298
+ scale (Optional[int]):
299
+ Scale factor for the RetNet attention scores.
300
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
301
+ initial_state (Optional[torch.Tensor]):
302
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
303
+ For equal-length input sequences, `N` equals the batch size `B`.
304
+ Default: `None`.
305
+ output_final_state (Optional[bool]):
306
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
307
+ cu_seqlens (torch.LongTensor):
308
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
309
+ consistent with the FlashAttention API.
310
+ head_first (Optional[bool]):
311
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
312
+ Default: `False`.
313
+
314
+ Returns:
315
+ o (torch.Tensor):
316
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
317
+ final_state (torch.Tensor):
318
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
319
+
320
+ Examples::
321
+ >>> import torch
322
+ >>> import torch.nn.functional as F
323
+ >>> from einops import rearrange
324
+ >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
325
+ # inputs with equal lengths
326
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
327
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
328
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
329
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
330
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
331
+ >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
332
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
333
+ >>> o, ht = chunk_gated_delta_rule(
334
+ q, k, v, g, beta,
335
+ initial_state=h0,
336
+ output_final_state=True,
337
+ head_first=False
338
+ )
339
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
340
+ >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
341
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
342
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
343
+ >>> o_var, ht_var = chunk_gated_delta_rule(
344
+ q, k, v, g, beta,
345
+ initial_state=h0,
346
+ output_final_state=True,
347
+ cu_seqlens=cu_seqlens,
348
+ head_first=False
349
+ )
350
+ """
351
+ assert q.dtype == k.dtype == v.dtype
352
+ assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
353
+ assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False."
354
+
355
+ if cu_seqlens is not None:
356
+ if q.shape[0] != 1:
357
+ raise ValueError(
358
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
359
+ f"Please flatten variable-length inputs before processing."
360
+ )
361
+ if head_first:
362
+ raise RuntimeError(
363
+ "Sequences with variable lengths are not supported for head-first mode"
364
+ )
365
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
366
+ raise ValueError(
367
+ f"The number of initial states is expected to be equal to the number of input sequences, "
368
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
369
+ )
370
+ if head_first:
371
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
372
+ beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g))
373
+ if scale is None:
374
+ scale = k.shape[-1] ** -0.5
375
+ else:
376
+ assert scale > 0, "Scale must be positive."
377
+ o, final_state = ChunkGatedDeltaRuleFunction.apply(
378
+ q,
379
+ k,
380
+ v,
381
+ g,
382
+ beta,
383
+ scale,
384
+ initial_state,
385
+ output_final_state,
386
+ cu_seqlens,
387
+ False,
388
+ use_qk_l2norm_in_kernel
389
+ )
390
+ if head_first:
391
+ o = rearrange(o, 'b t h v -> b h t v')
392
+ return o, final_state
fla/ops/generalized_delta_rule/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule
2
+ from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule',
7
+ 'chunk_iplr_delta_rule',
8
+ 'fused_recurrent_iplr_delta_rule'
9
+ ]
fla/ops/gla/fused_recurrent.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.common.fused_recurrent import fused_recurrent
9
+
10
+
11
+ def fused_recurrent_gla(
12
+ q: torch.Tensor,
13
+ k: torch.Tensor,
14
+ v: torch.Tensor,
15
+ gk: Optional[torch.Tensor] = None,
16
+ gv: Optional[torch.Tensor] = None,
17
+ scale: Optional[int] = None,
18
+ initial_state: Optional[torch.Tensor] = None,
19
+ output_final_state: bool = False,
20
+ reverse: bool = False,
21
+ cu_seqlens: Optional[torch.LongTensor] = None,
22
+ head_first: bool = True
23
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ r"""
25
+ Args:
26
+ q (torch.Tensor):
27
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
28
+ k (torch.Tensor):
29
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
30
+ v (torch.Tensor):
31
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
32
+ gk (torch.Tensor):
33
+ Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys.
34
+ gv (torch.Tensor):
35
+ Forget gates of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` applied to values.
36
+ scale (Optional[int]):
37
+ Scale factor for the attention scores.
38
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
39
+ initial_state (Optional[torch.Tensor]):
40
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
41
+ For equal-length input sequences, `N` equals the batch size `B`.
42
+ Default: `None`.
43
+ output_final_state (Optional[bool]):
44
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
45
+ reverse (Optional[bool]):
46
+ If `True`, process the state passing in reverse order. Default: `False`.
47
+ cu_seqlens (torch.LongTensor):
48
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
49
+ consistent with the FlashAttention API.
50
+ head_first (Optional[bool]):
51
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
52
+ Default: `True`.
53
+
54
+ Returns:
55
+ o (torch.Tensor):
56
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
57
+ final_state (torch.Tensor):
58
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
59
+
60
+ Examples::
61
+ >>> import torch
62
+ >>> import torch.nn.functional as F
63
+ >>> from einops import rearrange
64
+ >>> from fla.ops.gla import fused_recurrent_gla
65
+ # inputs with equal lengths
66
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
67
+ >>> q = torch.randn(B, T, H, K, device='cuda')
68
+ >>> k = torch.randn(B, T, H, K, device='cuda')
69
+ >>> v = torch.randn(B, T, H, V, device='cuda')
70
+ >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
71
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
72
+ >>> o, ht = fused_recurrent_gla(q, k, v, g,
73
+ initial_state=h0,
74
+ output_final_state=True,
75
+ head_first=False)
76
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
77
+ >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
78
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
79
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
80
+ >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g,
81
+ initial_state=h0,
82
+ output_final_state=True,
83
+ cu_seqlens=cu_seqlens,
84
+ head_first=False)
85
+ >>> assert o.allclose(o_var.view(o.shape))
86
+ >>> assert ht.allclose(ht_var)
87
+ """
88
+ if cu_seqlens is not None:
89
+ if q.shape[0] != 1:
90
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
91
+ f"Please flatten variable-length inputs before processing.")
92
+ if head_first:
93
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
94
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
95
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
96
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
97
+ if scale is None:
98
+ scale = k.shape[-1] ** -0.5
99
+ o, final_state = fused_recurrent(
100
+ q=q,
101
+ k=k,
102
+ v=v,
103
+ g=None,
104
+ gk=gk,
105
+ gv=gv,
106
+ scale=scale,
107
+ initial_state=initial_state,
108
+ output_final_state=output_final_state,
109
+ reverse=reverse,
110
+ cu_seqlens=cu_seqlens,
111
+ head_first=head_first
112
+ )
113
+ return o, final_state
fla/ops/gla/naive.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ def ceildiv(a, b):
9
+ return -(a // -b)
10
+
11
+
12
+ def naive_recurrent_gla(
13
+ q: torch.Tensor,
14
+ k: torch.Tensor,
15
+ v: torch.Tensor,
16
+ gk: torch.Tensor,
17
+ initial_state: Optional[torch.Tensor] = None,
18
+ output_final_state: bool = False
19
+ ):
20
+ dtype = q.dtype
21
+ q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk))
22
+ B, H, T, K, V = *q.shape, v.shape[-1]
23
+ o = torch.zeros_like(v)
24
+ scale = K ** -0.5
25
+
26
+ h = q.new_zeros(B, H, K, V, dtype=torch.float32)
27
+ if initial_state is not None:
28
+ h += initial_state.float()
29
+
30
+ for i in range(T):
31
+ q_i = q[:, :, i] * scale
32
+ k_i = k[:, :, i]
33
+ v_i = v[:, :, i]
34
+ gk_i = gk[:, :, i].exp()
35
+ kv_i = k_i[..., None] * v_i[..., None, :]
36
+ h = h * gk_i[..., None] + kv_i
37
+ o[:, :, i] = (q_i[..., None] * h).sum(-2)
38
+
39
+ if not output_final_state:
40
+ h = None
41
+ return o.to(dtype), h
fla/ops/gsa/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gsa
4
+ from .fused_recurrent import fused_recurrent_gsa
5
+
6
+ __all__ = [
7
+ 'chunk_gsa',
8
+ 'fused_recurrent_gsa'
9
+ ]
fla/ops/hgrn/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_hgrn
4
+ from .fused_recurrent import fused_recurrent_hgrn
5
+
6
+ __all__ = [
7
+ 'chunk_hgrn',
8
+ 'fused_recurrent_hgrn'
9
+ ]
fla/ops/hgrn/fused_recurrent.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BD': BD}, num_warps=num_warps)
22
+ for BD in [32, 64, 128]
23
+ for num_warps in [1, 2, 4, 8]
24
+ ],
25
+ key=['D']
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_hgrn_fwd_kernel(
29
+ x,
30
+ g,
31
+ o,
32
+ h0,
33
+ ht,
34
+ offsets,
35
+ T,
36
+ D: tl.constexpr,
37
+ BD: tl.constexpr,
38
+ USE_INITIAL_STATE: tl.constexpr,
39
+ STORE_FINAL_STATE: tl.constexpr,
40
+ USE_OFFSETS: tl.constexpr
41
+ ):
42
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
43
+ if USE_OFFSETS:
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_n * T, i_n * T + T
48
+
49
+ o_d = i_d * BD + tl.arange(0, BD)
50
+ mask = o_d < D
51
+
52
+ p_x = x + bos * D + o_d
53
+ p_g = g + bos * D + o_d
54
+ p_o = o + bos * D + o_d
55
+
56
+ b_h = tl.zeros([BD], dtype=tl.float32)
57
+ if USE_INITIAL_STATE:
58
+ p_h0 = h0 + i_n * D + o_d
59
+ b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
60
+ for _ in range(0, T):
61
+ b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
62
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
63
+ b_h = exp(b_g) * b_h + b_x
64
+ tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
65
+
66
+ p_x += D
67
+ p_g += D
68
+ p_o += D
69
+
70
+ if STORE_FINAL_STATE:
71
+ p_ht = ht + i_n * D + o_d
72
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
73
+
74
+
75
+ @triton.heuristics({
76
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
77
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
78
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
79
+ })
80
+ @triton.autotune(
81
+ configs=[
82
+ triton.Config({'BD': BD}, num_warps=num_warps)
83
+ for BD in [32, 64, 128]
84
+ for num_warps in [1, 2, 4, 8]
85
+ ],
86
+ key=['D']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fused_recurrent_hgrn_bwd_kernel(
90
+ g,
91
+ o,
92
+ h0,
93
+ dx,
94
+ dg,
95
+ do,
96
+ dht,
97
+ dh0,
98
+ offsets,
99
+ T,
100
+ D: tl.constexpr,
101
+ BD: tl.constexpr,
102
+ USE_INITIAL_STATE: tl.constexpr,
103
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
104
+ USE_OFFSETS: tl.constexpr
105
+ ):
106
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
107
+ if USE_OFFSETS:
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_n * T, i_n * T + T
112
+
113
+ o_d = i_d * BD + tl.arange(0, BD)
114
+ mask = o_d < D
115
+
116
+ p_g = g + (bos + T - 1) * D + o_d
117
+ p_o = o + (bos + T - 2) * D + o_d
118
+ p_dx = dx + (bos + T - 1) * D + o_d
119
+ p_dg = dg + (bos + T - 1) * D + o_d
120
+ p_do = do + (bos + T - 1) * D + o_d
121
+
122
+ b_dh = tl.zeros([BD], dtype=tl.float32)
123
+ if USE_FINAL_STATE_GRADIENT:
124
+ p_dht = dht + i_n * D + o_d
125
+ b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)
126
+
127
+ for i in range(T - 1, -1, -1):
128
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
129
+ b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
130
+ if i > 0:
131
+ b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
132
+ elif USE_INITIAL_STATE:
133
+ b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
134
+ else:
135
+ b_o = tl.zeros([BD], dtype=tl.float32)
136
+
137
+ b_dh = b_dh + b_do
138
+ b_dx = b_dh
139
+ b_dh = b_dh * exp(b_g)
140
+ b_dg = b_dh * b_o
141
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
142
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
143
+
144
+ p_g -= D
145
+ p_o -= D
146
+ p_dx -= D
147
+ p_dg -= D
148
+ p_do -= D
149
+
150
+ if USE_INITIAL_STATE:
151
+ p_dh0 = dh0 + i_n * D + o_d
152
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
153
+
154
+
155
+ def fused_recurrent_hgrn_fwd(
156
+ x: torch.Tensor,
157
+ g: torch.Tensor,
158
+ initial_state: torch.Tensor = None,
159
+ output_final_state: bool = False,
160
+ offsets: Optional[torch.LongTensor] = None,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ B, T, D = x.shape
163
+ N = B if offsets is None else len(offsets) - 1
164
+
165
+ o = torch.empty_like(x)
166
+ final_state = x.new_empty(N, D) if output_final_state else None
167
+
168
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
169
+ fused_recurrent_hgrn_fwd_kernel[grid](
170
+ x=x,
171
+ g=g,
172
+ o=o,
173
+ h0=initial_state,
174
+ ht=final_state,
175
+ offsets=offsets,
176
+ T=T,
177
+ D=D
178
+ )
179
+ return o, final_state
180
+
181
+
182
+ def fused_recurrent_hgrn_bwd(
183
+ g: torch.Tensor,
184
+ o: torch.Tensor,
185
+ do: torch.Tensor,
186
+ dht: torch.Tensor = None,
187
+ initial_state: torch.Tensor = None,
188
+ offsets: Optional[torch.LongTensor] = None
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ B, T, D = do.shape
191
+ N = B if offsets is None else len(offsets) - 1
192
+
193
+ dx = torch.empty_like(o, dtype=torch.float)
194
+ dg = torch.empty_like(g, dtype=torch.float)
195
+ dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None
196
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
197
+ fused_recurrent_hgrn_bwd_kernel[grid](
198
+ g=g,
199
+ o=o,
200
+ h0=initial_state,
201
+ dx=dx,
202
+ dg=dg,
203
+ do=do,
204
+ dht=dht,
205
+ dh0=dh0,
206
+ offsets=offsets,
207
+ T=T,
208
+ D=D
209
+ )
210
+ return dx, dg, dh0
211
+
212
+
213
+ class FusedRecurrentHGRNFunction(torch.autograd.Function):
214
+
215
+ @staticmethod
216
+ @input_guard
217
+ def forward(
218
+ ctx,
219
+ x: torch.Tensor,
220
+ g: torch.Tensor,
221
+ initial_state: torch.Tensor = None,
222
+ output_final_state: bool = False,
223
+ offsets: Optional[torch.LongTensor] = None
224
+ ):
225
+ o, ht = fused_recurrent_hgrn_fwd(
226
+ x=x,
227
+ g=g,
228
+ initial_state=initial_state,
229
+ output_final_state=output_final_state,
230
+ offsets=offsets
231
+ )
232
+ ctx.save_for_backward(g, o, initial_state)
233
+ ctx.offsets = offsets
234
+ return o, ht
235
+
236
+ @staticmethod
237
+ @input_guard
238
+ def backward(ctx, do, dht=None):
239
+ g, o, initial_state = ctx.saved_tensors
240
+ offsets = ctx.offsets
241
+
242
+ dx, dg, dh0 = fused_recurrent_hgrn_bwd(
243
+ g=g,
244
+ o=o,
245
+ do=do,
246
+ dht=dht,
247
+ initial_state=initial_state,
248
+ offsets=offsets
249
+ )
250
+ return dx, dg, dh0, None, None
251
+
252
+
253
+ @torch.compiler.disable
254
+ def fused_recurrent_hgrn(
255
+ x: torch.Tensor,
256
+ g: torch.Tensor,
257
+ initial_state: torch.Tensor = None,
258
+ output_final_state: bool = False,
259
+ cu_seqlens: Optional[torch.LongTensor] = None,
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ r"""
262
+ Args:
263
+ x (torch.Tensor):
264
+ inputs of shape `[B, T, D].
265
+ g (torch.Tensor):
266
+ Forget gates of shape `[B, T, D]`.
267
+ initial_state (Optional[torch.Tensor]):
268
+ Initial state of shape `[N, D]` for `N` input sequences.
269
+ For equal-length input sequences, `N` equals the batch size `B`.
270
+ Default: `None`.
271
+ output_final_state (Optional[bool]):
272
+ Whether to output the final state of shape `[N, D]`. Default: `False`.
273
+ cu_seqlens (torch.LongTensor):
274
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
275
+ consistent with the FlashAttention API.
276
+
277
+ Returns:
278
+ o (torch.Tensor):
279
+ Outputs of shape `[B, T, D]`.
280
+ final_state (torch.Tensor):
281
+ Final state of shape `[N, D]` if `output_final_state=True` else `None`.
282
+
283
+ Examples::
284
+ >>> import torch
285
+ >>> import torch.nn.functional as F
286
+ >>> from einops import rearrange
287
+ >>> from fla.ops.hgrn import fused_recurrent_hgrn
288
+ # inputs with equal lengths
289
+ >>> B, T, D = 4, 2048, 512
290
+ >>> x = torch.randn(B, T, D, device='cuda')
291
+ >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
292
+ >>> h0 = torch.randn(B, D, device='cuda')
293
+ >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
294
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
295
+ >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
296
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
297
+ >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
298
+ >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens)
299
+ >>> assert o.allclose(o_var.view(o.shape))
300
+ >>> assert ht.allclose(ht_var)
301
+ """
302
+ return FusedRecurrentHGRNFunction.apply(
303
+ x,
304
+ g,
305
+ initial_state,
306
+ output_final_state,
307
+ cu_seqlens
308
+ )