Commit b4ec538 (verified) · Parent: 463b500
Committed by zaydzuhri

Add files using upload-large-folder tool

Note: this view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. fla/layers/attn.py +240 -0
  2. fla/layers/bitattn.py +192 -0
  3. fla/layers/forgetting_attn.py +109 -0
  4. fla/layers/gated_deltanet.py +293 -0
  5. fla/layers/hgrn.py +168 -0
  6. fla/layers/hgrn2.py +211 -0
  7. fla/layers/lightnet.py +210 -0
  8. fla/layers/nsa.py +138 -0
  9. fla/layers/rebased.py +133 -0
  10. fla/layers/simple_gla.py +261 -0
  11. fla/modules/__pycache__/rotary.cpython-312.pyc +0 -0
  12. fla/modules/parallel.py +37 -0
  13. flame/__pycache__/config_manager.cpython-312.pyc +0 -0
  14. flame/components/__pycache__/checkpoint.cpython-312.pyc +0 -0
  15. flame/models/__init__.py +0 -0
  16. flame/models/fla.toml +67 -0
  17. flame/models/parallelize_fla.py +550 -0
  18. flame/tools/__pycache__/__init__.cpython-312.pyc +0 -0
  19. flame/tools/__pycache__/utils.cpython-312.pyc +0 -0
  20. flame/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  21. flame/utils/__pycache__/checkpoint.cpython-312.pyc +0 -0
  22. flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc +0 -0
  23. flame/utils/__pycache__/hf_utils.cpython-312.pyc +0 -0
  24. flame/utils/convert_hf_to_dcp.py +34 -0
  25. logs/none_99omtdbz/attempt_0/0/stdout.log +0 -0
  26. logs/none_99omtdbz/attempt_0/3/stdout.log +0 -0
  27. logs/none_99omtdbz/attempt_0/4/stdout.log +0 -0
  28. logs/none_99omtdbz/attempt_0/7/stdout.log +0 -0
  29. profile_trace/iteration_10752/rank4_trace.json +0 -0
  30. profile_trace/iteration_1536/rank1_trace.json +0 -0
  31. profile_trace/iteration_1536/rank2_trace.json +0 -0
  32. profile_trace/iteration_1536/rank4_trace.json +0 -0
  33. profile_trace/iteration_1536/rank7_trace.json +0 -0
  34. profile_trace/iteration_2048/rank0_trace.json +0 -0
  35. profile_trace/iteration_2048/rank3_trace.json +0 -0
  36. profile_trace/iteration_25088/rank0_trace.json +0 -0
  37. profile_trace/iteration_25088/rank1_trace.json +0 -0
  38. profile_trace/iteration_25088/rank2_trace.json +0 -0
  39. profile_trace/iteration_25088/rank4_trace.json +0 -0
  40. profile_trace/iteration_31232/rank3_trace.json +0 -0
  41. profile_trace/iteration_31232/rank7_trace.json +0 -0
  42. profile_trace/iteration_34304/rank0_trace.json +0 -0
  43. profile_trace/iteration_34304/rank2_trace.json +0 -0
  44. profile_trace/iteration_34304/rank5_trace.json +0 -0
  45. profile_trace/iteration_34304/rank6_trace.json +0 -0
  46. profile_trace/iteration_5120/rank3_trace.json +0 -0
  47. profile_trace/iteration_5120/rank4_trace.json +0 -0
  48. profile_trace/iteration_9728/rank3_trace.json +0 -0
  49. profile_trace/iteration_9728/rank4_trace.json +0 -0
  50. profile_trace/iteration_9728/rank6_trace.json +0 -0
fla/layers/attn.py ADDED
@@ -0,0 +1,240 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.utils import logging

from fla.modules import RMSNorm, RotaryEmbedding

if TYPE_CHECKING:
    from fla.models.utils import Cache

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`."
        " Falling back to use SDPA's attention implementation.",
        category=ImportWarning
    )
    flash_attn_func = None

import os
if os.getenv("FLASH_ATTENTION_DISABLE", "0") == "1":
    flash_attn_func = None

logger = logging.get_logger(__name__)


class Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: int = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm

        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        # equivalent to cu_seqlens in `flash_attn`
        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to delineate the offsets of padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            if flash_attn_func is not None:
                q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
                cu_seqlens_q, cu_seqlens_k = cu_seq_lens
                max_seqlen_q, max_seqlen_k = max_seq_lens
                o = flash_attn_varlen_func(
                    q, k, v,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    causal=True,
                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
                )  # B S H D
                o = pad_input(o, indices_q, batch_size, q_len)  # B S H D
            else:
                attention_mask = attention_mask.bool()
                q = rearrange(q, 'b s h d -> b h s d')  # B H S D
                k = rearrange(k, 'b s h d -> b h s d')  # B H S D
                v = rearrange(v, 'b s h d -> b h s d')  # B H S D
                o = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attention_mask,
                    dropout_p=0.0,
                    is_causal=False,
                )  # B, H, S, D
                o = rearrange(o, 'b h s d -> b s h d')
        elif cu_seqlens is not None:
            if flash_attn_func is not None:
                o = flash_attn_varlen_func(
                    q.squeeze(0), k.squeeze(0), v.squeeze(0),
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens,
                    max_seqlen_q=max_seqlen,
                    max_seqlen_k=max_seqlen,
                    causal=True,
                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
                ).unsqueeze(0)
            else:
                # o = F.scaled_dot_product_attention(
                #     q.squeeze(0), k.squeeze(0), v.squeeze(0),
                #     dropout_p=0.0,
                #     is_causal=True
                # ).unsqueeze(0)
                raise NotImplementedError(
                    "SDPA does not support variable length inputs with cu_seqlens. "
                    "Please use flash_attn_func for variable length inputs."
                )
        else:
            if flash_attn_func is not None:
                o = flash_attn_func(
                    q, k, v,
                    causal=True,
                    window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
                )
            else:
                q = rearrange(q, 'b s h d -> b h s d')  # B H S D
                k = rearrange(k, 'b s h d -> b h s d')  # B H S D
                v = rearrange(v, 'b s h d -> b h s d')  # B H S D
                o = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=0.0,
                    is_causal=True
                )
                o = rearrange(o, 'b h s d -> b s h d')
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        if not output_attentions:
            attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
        cache_mask = attention_mask[:, -seq_len:]
        seqlens = cache_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q, *_ = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
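
A minimal usage sketch for the layer above (not part of the commit); the import path and the small sizes are assumptions, and `FLASH_ATTENTION_DISABLE=1` is set so the SDPA fallback defined above runs in fp32 on CPU:

import os
os.environ["FLASH_ATTENTION_DISABLE"] = "1"  # must be set before importing the module, which reads it at import time

import torch
from fla.layers.attn import Attention  # assumed import path for fla/layers/attn.py

layer = Attention(hidden_size=256, num_heads=4, layer_idx=0)
x = torch.randn(2, 16, 256)      # [batch_size, seq_len, hidden_size]
o, attentions, past = layer(x)   # no padding mask, so the causal SDPA path is taken
print(o.shape)                   # torch.Size([2, 16, 256])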
fla/layers/bitattn.py ADDED
@@ -0,0 +1,192 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.utils import logging

from fla.modules import RotaryEmbedding
from fla.modules.fused_bitlinear import FusedBitLinear

if TYPE_CHECKING:
    from fla.models.utils import Cache

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning
    )
    flash_attn_func = None

logger = logging.get_logger(__name__)


class BitAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ):
        super().__init__()

        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)

        # equivalent to cu_seqlens in `flash_attn`
        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to delineate the offsets of padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
            o = pad_input(o, indices_q, batch_size, q_len)
        elif cu_seqlens is not None:
            o = flash_attn_varlen_func(
                q.squeeze(0), k.squeeze(0), v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            ).unsqueeze(0)
        else:
            o = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        if not output_attentions:
            attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
        cache_mask = attention_mask[:, -seq_len:]
        seqlens = cache_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
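
A usage sketch for BitAttention (not part of the commit). Because this layer hard-requires flash-attn, the sketch assumes a CUDA device, half precision, and the import path shown:

import torch
from fla.layers.bitattn import BitAttention  # assumed import path for fla/layers/bitattn.py

layer = BitAttention(hidden_size=256, num_heads=4, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 16, 256, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)   # FusedBitLinear replaces the dense projections; flash_attn_func handles the causal attention
print(o.shape)       # torch.Size([2, 16, 256])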
fla/layers/forgetting_attn.py ADDED
@@ -0,0 +1,109 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.utils import logging

from fla.modules import GroupNorm
from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn

if TYPE_CHECKING:
    from fla.models.utils import Cache


logger = logging.get_logger(__name__)


class ForgettingAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        window_size: Optional[int] = None,
        use_output_gate: bool = False,
        layer_idx: int = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm

        self.window_size = window_size
        self.use_output_gate = use_output_gate
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

        if use_output_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if qk_norm:
            self.q_norm = GroupNorm(
                num_groups=self.num_heads,
                hidden_size=self.hidden_size,
                is_rms_norm=True,
            )
            self.k_norm = GroupNorm(
                num_groups=self.num_kv_heads,
                hidden_size=self.kv_dim,
                is_rms_norm=True,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        cu_seqlens = kwargs.get('cu_seqlens', None)
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        f = F.logsigmoid(self.f_proj(hidden_states).float())
        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens)
        o = rearrange(o, '... h d -> ... (h d)')
        if self.use_output_gate:
            o = self.g_proj(hidden_states).sigmoid() * o
        o = self.o_proj(o)

        return o, None, past_key_values
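
A usage sketch for ForgettingAttention (not part of the commit); the import path is assumed, and the Triton kernel behind parallel_forgetting_attn is assumed to need a CUDA device:

import torch
from fla.layers.forgetting_attn import ForgettingAttention  # assumed import path

layer = ForgettingAttention(hidden_size=256, num_heads=4, use_output_gate=True, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 64, 256, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)   # per-head forget gates f = logsigmoid(f_proj(x)) are computed inside forward
print(o.shape)       # torch.Size([2, 64, 256])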
fla/layers/gated_deltanet.py ADDED
@@ -0,0 +1,293 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import math
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


@torch.compile
def elu_p1(x):
    return (F.elu(x, 1., False) + 1.).to(x)


@torch.compile
def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


class GatedDeltaNet(nn.Module):
    """
    The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464).  # noqa

    Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.

    Parameter allocation when use_gate=True:
        - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
        - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
        - Others are negligibly small.
        - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
    NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.

    Parameter allocation when use_gate=False:
        - 1 * hidden_size * hidden_size for the q_proj and k_proj each
        - 2 * hidden_size * hidden_size for the v_proj and o_proj each
        - Others are negligibly small.
        - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        head_dim (int, Optional):
            The dimension of each head. Default: 256.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        mode (str, Optional):
            Which Gated DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            Default: `chunk`.
        use_beta (bool, Optional):
            Whether to use beta. Default: `True`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the normalization layer. Default: 1e-5.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        mode: str = 'chunk',
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int = None,
        norm_eps: float = 1e-5,
        **kwargs
    ) -> GatedDeltaNet:
        super().__init__()

        self.mode = mode

        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.head_dim = head_dim
        self.num_heads = num_heads

        self.key_dim = int(self.num_heads * self.head_dim)
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_k_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx

        # Consistency check: Ensure expand_v produces integer values
        if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
                f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear."
            )
        if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
            )
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)

        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # hard coded for now
        dt_min = 0.001
        dt_max = 0.1
        dt_init_floor = 1e-4
        dt = torch.exp(
            torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            raise UserWarning(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = F.silu(self.q_proj(hidden_states))
            k = F.silu(self.k_proj(hidden_states))
            v = F.silu(self.v_proj(hidden_states))

        q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        beta = self.b_proj(hidden_states).sigmoid()
        g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

        # dealing with padding
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
            g = g.mul(attention_mask[:, -g.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            o, recurrent_state = chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
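
A usage sketch illustrating the parameter-allocation note in the docstring (not part of the commit; the import path and the CUDA/bf16 requirement of the Triton kernels are assumptions):

import torch
from fla.layers.gated_deltanet import GatedDeltaNet  # assumed import path

# hidden_size=2048, head_dim=256, num_heads=6 gives key_dim = 6*256 = 1536 = 0.75*hidden_size
# and value_dim = 1536*2 = 3072, matching the 0.75/1.5 split described in the docstring above.
layer = GatedDeltaNet(hidden_size=2048, head_dim=256, num_heads=6, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 2048, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)   # seq_len > 64, so the chunk kernel is selected
print(o.shape)       # torch.Size([2, 128, 2048])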
fla/layers/hgrn.py ADDED
@@ -0,0 +1,168 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fla.modules import FusedRMSNormGated, ShortConvolution
from fla.modules.activations import swiglu
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class HGRNAttention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ) -> HGRNAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_ratio = expand_ratio
        self.input_dim = int(hidden_size * expand_ratio)

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)

        self.g_norm = FusedRMSNormGated(
            hidden_size=self.input_dim,
            elementwise_affine=elementwise_affine,
            eps=norm_eps
        )
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_i, conv_state_f = None, None
            if last_state is not None:
                conv_state_i, conv_state_f = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            i, conv_state_i = self.i_conv1d(
                x=self.i_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_i,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            f, conv_state_f = self.f_conv1d(
                x=self.f_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_f,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            i = self.i_proj(hidden_states)
            f = self.f_proj(hidden_states)

        # the lower bound for the first layer is zero
        if lower_bound is None or self.layer_idx == 0:
            i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
        else:
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
            i, f = swiglu(i, 1 - g), g.log()

        # dealing with left-padding
        if attention_mask is not None:
            i = i.mul_(attention_mask[:, -i.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            if cu_seqlens is not None:
                raise NotImplementedError("Chunk mode does not support variable-length sequences.")
            o, recurrent_state = chunk_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=i.shape[2]
            )

        o = self.g_norm(o, self.g_proj(hidden_states))
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        state_size = self.hidden_size
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
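
A usage sketch for HGRNAttention (not part of the commit); the import path and the CUDA requirement of the Triton kernels are assumptions:

import torch
from fla.layers.hgrn import HGRNAttention  # assumed import path

layer = HGRNAttention(hidden_size=256, expand_ratio=1, layer_idx=0).cuda()
x = torch.randn(2, 128, 256, device='cuda')
# lower_bound is normally supplied by the wrapping model; None means a zero lower bound
o, _, _ = layer(x, lower_bound=None)
print(o.shape)   # torch.Size([2, 128, 256])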
fla/layers/hgrn2.py ADDED
@@ -0,0 +1,211 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# "HGRN2: Gated Linear RNNs with State Expansion" [https://arxiv.org/abs/2404.07904]

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules import RMSNorm, ShortConvolution
from fla.modules.activations import swish
from fla.modules.layernorm import rms_norm_linear
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class HGRN2Attention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        num_heads: Optional[int] = None,
        expand_ratio: Optional[int] = 128,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ) -> HGRN2Attention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size

        if expand_ratio is None and num_heads is not None:
            expand_ratio = hidden_size // num_heads
        elif expand_ratio is not None and num_heads is None:
            num_heads = hidden_size // expand_ratio
        elif expand_ratio is None and num_heads is None:
            raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
        self.num_heads = num_heads
        self.expand_ratio = expand_ratio

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.forget_dim = int(self.num_heads * self.expand_ratio)
        self.input_dim = hidden_size
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
        assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}"
        assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}"

        self.head_f_dim = self.expand_ratio
        self.head_i_dim = self.hidden_size // num_heads

        self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
        self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False)
        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
            self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None)
            self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)

        self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps)
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_f, conv_state_i = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_f, conv_state_i = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            f, conv_state_f = self.f_conv1d(
                x=self.f_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_f,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            i, conv_state_i = self.i_conv1d(
                x=self.i_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_i,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            f = self.f_proj(hidden_states)
            i = self.i_proj(hidden_states)

        # dealing with left-padding
        if attention_mask is not None:
            i = i.mul_(attention_mask[:, -i.shape[-2]:, None])

        q = swish(q)

        # improve precision
        f = f.float()

        # the lower bound for the first layer is zero
        if lower_bound is None or self.layer_idx == 0:
            k, g = 1 - f.sigmoid(), F.logsigmoid(f)
        else:
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
            k, g = 1 - g, g.log()

        q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g))
        i = rearrange(i, '... (h d) -> ... h d', d=self.head_i_dim)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(
                q=q,
                k=k,
                v=i,
                gk=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_gla(
                q=q,
                k=k,
                v=i,
                g=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(
                q=q,
                k=k,
                v=i,
                g=g,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        o = rearrange(o, '... h d -> ... (h d)')
        o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        state_size = self.forget_dim * self.head_i_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
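
A usage sketch for HGRN2Attention (not part of the commit); the import path and the CUDA/bf16 requirement of the GLA kernels are assumptions:

import torch
from fla.layers.hgrn2 import HGRN2Attention  # assumed import path

# expand_ratio=128 with hidden_size=1024 implies num_heads = 1024 // 128 = 8
layer = HGRN2Attention(hidden_size=1024, expand_ratio=128, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)   # chunk GLA kernel, since seq_len > 64
print(o.shape)       # torch.Size([2, 128, 1024])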
fla/layers/lightnet.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # ["You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet"](https://arxiv.org/abs/2405.21022)
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import FusedRMSNormGated, ShortConvolution
16
+ from fla.modules.fused_norm_gate import rms_norm_swish_gate_linear
17
+ from fla.ops.gla import chunk_gla, fused_recurrent_gla
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class LightNetAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ num_heads: Optional[int] = None,
32
+ expand_ratio: Optional[int] = 128,
33
+ use_short_conv: bool = False,
34
+ conv_size: int = 4,
35
+ conv_bias: bool = False,
36
+ gate_low_rank_dim: int = 128,
37
+ elementwise_affine: Optional[bool] = True,
38
+ norm_eps: float = 1e-5,
39
+ layer_idx: int = None
40
+ ) -> LightNetAttention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ self.hidden_size = hidden_size
45
+
46
+ if expand_ratio is None and num_heads is not None:
47
+ expand_ratio = hidden_size // num_heads
48
+ elif expand_ratio is not None and num_heads is None:
49
+ num_heads = hidden_size // expand_ratio
50
+ elif expand_ratio is None and num_heads is None:
51
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
52
+ self.num_heads = num_heads
53
+ self.expand_ratio = expand_ratio
54
+
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.conv_bias = conv_bias
58
+
59
+ self.key_dim = int(self.num_heads * self.expand_ratio)
60
+ self.value_dim = hidden_size
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.layer_idx = layer_idx
63
+
64
+ assert mode in ['chunk', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
65
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
66
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
67
+
68
+ self.head_f_dim = self.expand_ratio
69
+ self.head_i_dim = self.hidden_size // num_heads
70
+
71
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
72
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
73
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
74
+
75
+ if use_short_conv:
76
+ self.conv_size = conv_size
77
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
78
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
79
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation=None)
80
+
81
+ self.g_proj = nn.Sequential(
82
+ nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
83
+ nn.Linear(gate_low_rank_dim, hidden_size, bias=False)
84
+ )
85
+ self.g_norm = FusedRMSNormGated(
86
+ hidden_size=hidden_size,
87
+ elementwise_affine=elementwise_affine,
88
+ eps=norm_eps
89
+ )
90
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
91
+
92
+ def forward(
93
+ self,
94
+ hidden_states: torch.Tensor,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
+ past_key_values: Optional[Cache] = None,
97
+ use_cache: Optional[bool] = False,
98
+ output_attentions: Optional[bool] = False,
99
+ **kwargs: Unpack[Dict]
100
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
101
+ if attention_mask is not None:
102
+ assert len(attention_mask.shape) == 2, (
103
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
104
+ "for padding purposes (0 indicating padding). "
105
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
106
+ )
107
+
108
+ # launching the triton kernel for just one token will actually be slower
109
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
110
+
111
+ last_state = None
112
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
113
+ last_state = past_key_values[self.layer_idx]
114
+
115
+ cu_seqlens = kwargs.get('cu_seqlens', None)
116
+ if self.use_short_conv:
117
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
118
+ if last_state is not None:
119
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
120
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
121
+ q, conv_state_q = self.q_conv1d(
122
+ x=self.q_proj(hidden_states),
123
+ mask=conv_mask,
124
+ cache=conv_state_q,
125
+ output_final_state=use_cache,
126
+ cu_seqlens=cu_seqlens
127
+ )
128
+ k, conv_state_k = self.k_conv1d(
129
+ x=self.k_proj(hidden_states),
130
+ mask=conv_mask,
131
+ cache=conv_state_k,
132
+ output_final_state=use_cache,
133
+ cu_seqlens=cu_seqlens
134
+ )
135
+ v, conv_state_v = self.v_conv1d(
136
+ x=self.v_proj(hidden_states),
137
+ mask=conv_mask,
138
+ cache=conv_state_v,
139
+ output_final_state=use_cache,
140
+ cu_seqlens=cu_seqlens
141
+ )
142
+ else:
143
+ q = self.q_proj(hidden_states)
144
+ k = self.k_proj(hidden_states)
145
+ v = self.v_proj(hidden_states)
146
+
147
+ # dealing with left-padding
148
+ if attention_mask is not None:
149
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
150
+
151
+ q = F.silu(q)
152
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k))
153
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_i_dim)
154
+ # TODO: this 2 steps took huge amount of time, which should be optimized
155
+ z = k.float().logcumsumexp(1)
156
+
157
+ if cu_seqlens is not None:
158
+ raise NotImplementedError("LightNet does not support variable-length sequences for now.")
159
+ k, g = torch.exp(k - z).to(k.dtype), (torch.cat((z[:, :1], z[:, :-1]), 1) - z).to(k.dtype)
160
+
161
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
162
+ if mode == 'fused_recurrent':
163
+ o, recurrent_state = fused_recurrent_gla(
164
+ q=q,
165
+ k=k,
166
+ v=v,
167
+ gk=g,
168
+ initial_state=recurrent_state,
169
+ output_final_state=use_cache,
170
+ cu_seqlens=cu_seqlens,
171
+ head_first=False
172
+ )
173
+ elif mode == 'chunk':
174
+ o, recurrent_state = chunk_gla(
175
+ q=q,
176
+ k=k,
177
+ v=v,
178
+ g=g,
179
+ initial_state=recurrent_state,
180
+ output_final_state=use_cache,
181
+ cu_seqlens=cu_seqlens,
182
+ head_first=False
183
+ )
184
+ else:
185
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
186
+
187
+ if past_key_values is not None:
188
+ past_key_values.update(
189
+ recurrent_state=recurrent_state,
190
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
191
+ layer_idx=self.layer_idx,
192
+ offset=q.shape[1]
193
+ )
194
+
195
+ o = rms_norm_swish_gate_linear(
196
+ rearrange(o, 'b t h d -> b t (h d)'),
197
+ self.g_proj(hidden_states),
198
+ self.g_norm.weight,
199
+ self.g_norm.bias,
200
+ self.o_proj.weight,
201
+ self.o_proj.bias
202
+ )
203
+ return o, None, past_key_values
204
+
205
+ def state_size(self, **kwargs) -> int:
206
+ state_size = self.key_dim * self.head_i_dim
207
+ for module in self.children():
208
+ if isinstance(module, ShortConvolution):
209
+ state_size += module.state_size
210
+ return state_size
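The `logcumsumexp` normalization above is what makes this gating a causal softmax over the key scores: `exp(k - z)` are the normalized keys and `z[t-1] - z[t]` is the log decay applied to the running state. Below is a minimal numerical sketch of that identity in plain PyTorch; the scalar-key shapes are illustrative assumptions, not taken from this file.

```python
# Minimal sketch (not from the repo) of the gating identity used above: with
# z = k.logcumsumexp(dim), the recurrence
#     S_t = exp(z_{t-1} - z_t) * S_{t-1} + exp(k_t - z_t) * v_t
# equals a causal softmax over all key scores seen so far, sum_{s<=t} softmax(k_{<=t})_s * v_s.
import torch

torch.manual_seed(0)
T, D = 8, 4                          # toy sequence length and value dim (assumed shapes)
k = torch.randn(T)                   # one scalar key score per step for simplicity
v = torch.randn(T, D)

z = k.logcumsumexp(0)                # running log-normalizer
g = torch.cat([z[:1], z[:-1]]) - z   # log decay; the layer does the same along the time dim
k_norm = (k - z).exp()               # normalized keys

# recurrent accumulation, as a fused_recurrent-style kernel would do it
S = torch.zeros(D)
out_rec = []
for t in range(T):
    S = g[t].exp() * S + k_norm[t] * v[t]
    out_rec.append(S.clone())
out_rec = torch.stack(out_rec)

# direct causal softmax over the key scores seen so far
out_ref = torch.stack([
    (torch.softmax(k[: t + 1], dim=0)[:, None] * v[: t + 1]).sum(0) for t in range(T)
])

assert torch.allclose(out_rec, out_ref, atol=1e-5)
```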
fla/layers/nsa.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from transformers.utils import logging
12
+
13
+ from fla.modules import RotaryEmbedding
14
+ from fla.ops.nsa.parallel import parallel_nsa
15
+
16
+ if TYPE_CHECKING:
17
+ from fla.models.utils import Cache
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class NativeSparseAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ hidden_size: int = 2048,
27
+ num_heads: int = 64,
28
+ num_kv_heads: Optional[int] = 4,
29
+ head_dim: int = 64,
30
+ qkv_bias: bool = False,
31
+ block_size: Optional[int] = 64,
32
+ block_counts: Optional[Union[torch.LongTensor, int]] = 16,
33
+ window_size: Optional[int] = 512,
34
+ rope_theta: Optional[float] = 10000.,
35
+ max_position_embeddings: Optional[int] = None,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = head_dim
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+
51
+ self.block_size = block_size
52
+ self.block_counts = block_counts
53
+ self.window_size = window_size
54
+ self.rope_theta = rope_theta
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.layer_idx = layer_idx
57
+
58
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
61
+ self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
62
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
63
+
64
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
65
+
66
+ def forward(
67
+ self,
68
+ hidden_states: torch.Tensor,
69
+ attention_mask: Optional[torch.LongTensor] = None,
70
+ past_key_values: Optional[Cache] = None,
71
+ output_attentions: bool = False,
72
+ use_cache: bool = False,
73
+ **kwargs,
74
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
75
+ if attention_mask is not None:
76
+ assert len(attention_mask.shape) == 2, (
77
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
78
+ "for padding purposes (0 indicating padding). "
79
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
80
+ )
81
+
82
+ batch_size, seq_len, _ = hidden_states.size()
83
+
84
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
85
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
86
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
87
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
88
+ g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
89
+
90
+ cu_seqlens = kwargs.get('cu_seqlens', None)
91
+
92
+ seqlen_offset, max_seqlen = 0, seq_len
93
+ if past_key_values is not None:
94
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
95
+ max_seqlen = q.shape[1] + seqlen_offset
96
+
97
+ if attention_mask is not None:
98
+ # exclude padding tokens when computing the position offsets
99
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
100
+ max_seqlen = q.shape[1] + max(seqlen_offset)
101
+
102
+ if self.max_position_embeddings is not None:
103
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
104
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
105
+
106
+ if past_key_values is not None:
107
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
108
+ k_cached, v_cached = past_key_values.update(
109
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
110
+ layer_idx=self.layer_idx,
111
+ offset=seq_len,
112
+ cache_kwargs=dict(window_size=self.window_size)
113
+ )['attn_state']
114
+ if cache_has_content:
115
+ k, v = k_cached, v_cached
116
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
117
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
118
+
119
+ o = parallel_nsa(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ g_cmp=g_cmp,
124
+ g_slc=g_slc,
125
+ g_swa=g_swa,
126
+ block_size=self.block_size,
127
+ block_counts=self.block_counts,
128
+ window_size=self.window_size,
129
+ cu_seqlens=cu_seqlens,
130
+ head_first=False
131
+ )
132
+ o = o.reshape(batch_size, seq_len, -1)
133
+ o = self.o_proj(o)
134
+
135
+ if not output_attentions:
136
+ attentions = None
137
+
138
+ return o, attentions, past_key_values
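For reference, a hedged usage sketch of the layer above; the import path `fla.layers.nsa`, the parameter values, and the CUDA requirement (the `parallel_nsa` kernel is Triton-based) are assumptions for illustration.

```python
# Hedged usage sketch (assumed import path and shapes; requires a CUDA device
# because parallel_nsa is a Triton kernel).
import torch
from fla.layers.nsa import NativeSparseAttention

if torch.cuda.is_available():
    layer = NativeSparseAttention(
        hidden_size=2048, num_heads=64, num_kv_heads=4, head_dim=64,
        block_size=64, block_counts=16, window_size=512, layer_idx=0,
    ).to("cuda", torch.bfloat16)

    x = torch.randn(2, 1024, 2048, device="cuda", dtype=torch.bfloat16)
    o, attentions, past_key_values = layer(x)   # attentions is None unless requested
    print(o.shape)                              # expected: torch.Size([2, 1024, 2048])
```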
fla/layers/rebased.py ADDED
@@ -0,0 +1,133 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+
16
+ from fla.modules.feature_map import RebasedFeatureMap
17
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
18
+ from fla.ops.rebased import parallel_rebased
19
+
20
+
21
+ class ReBasedLinearAttention(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int,
26
+ l_max: int = 2048,
27
+ feature_dim: int = 16,
28
+ num_key_value_heads: int = 16,
29
+ num_heads: int = 16,
30
+ use_gamma: Optional[bool] = True,
31
+ use_beta: Optional[bool] = True,
32
+ normalize: Optional[bool] = True,
33
+ causal: bool = True,
34
+ eps: float = 1e-5,
35
+ mode: str = "parallel",
36
+ layer_idx: Optional[int] = None,
37
+ **kwargs
38
+ ) -> ReBasedLinearAttention:
39
+ super().__init__()
40
+ self.hidden_size = hidden_size
41
+ self.l_max = l_max
42
+ self.mode = mode
43
+ assert self.mode in ["fused_chunk", "parallel", 'chunk']
44
+
45
+ self.feature_dim = feature_dim
46
+ self.num_key_value_heads = num_key_value_heads
47
+ self.num_heads = num_heads
48
+ self.head_dim = self.hidden_size // self.num_key_value_heads
49
+ self.use_gamma = use_gamma
50
+ self.use_beta = use_beta
51
+ self.normalize = normalize
52
+ self.causal = causal
53
+ self.eps = eps
54
+ self.mode = mode
55
+ self.layer_idx = layer_idx
56
+
57
+ self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
58
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
61
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
62
+ self.dropout = nn.Identity()
63
+
64
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
65
+ mode = self.mode
66
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
67
+ q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
68
+ q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
69
+ if mode == "fused_chunk":
70
+ o = fused_chunk_linear_attn(
71
+ q=q,
72
+ k=k,
73
+ v=v,
74
+ normalize=True,
75
+ scale=1,
76
+ head_first=False
77
+ )
78
+ elif mode == 'chunk':
79
+ o = chunk_linear_attn(
80
+ q=q,
81
+ k=k,
82
+ v=v,
83
+ normalize=True,
84
+ scale=1,
85
+ head_first=False
86
+ )
87
+ elif mode == 'parallel':
88
+ assert q.shape[-1] <= 128
89
+ o = parallel_rebased(
90
+ q=q,
91
+ k=k,
92
+ v=v,
93
+ eps=self.eps,
94
+ use_scale=True,
95
+ use_normalize=True,
96
+ head_first=False
97
+ )
98
+ o = self.o_proj(o)
99
+ o = self.dropout(o)
100
+ return o
101
+
102
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
103
+ def forward_reference(
104
+ self,
105
+ hidden_states: torch.Tensor,
106
+ filters: torch.Tensor = None,
107
+ *args,
108
+ **kwargs
109
+ ):
110
+ """
111
+ hidden_states (torch.Tensor): input tensor of shape (b, t, d)
112
+ returns (torch.Tensor): output tensor of shape (b, t, d)
113
+ """
114
+ b, t, _ = hidden_states.size()
115
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
116
+
117
+ q = q.view(b, t, -1, self.feature_dim).transpose(1, 2)
118
+ k = k.view(b, t, -1, self.feature_dim).transpose(1, 2)
119
+ v = v.view(b, t, -1, self.head_dim).transpose(1, 2)
120
+
121
+ # Linear attention
122
+ q, k = self.feature_map(q), self.feature_map(k)
123
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
124
+
125
+ # Compute attention
126
+ if self.causal:
127
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
128
+ else:
129
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
130
+ y = rearrange(y, 'b h t d -> b t (h d)')
131
+ y = self.o_proj(y.to(hidden_states.dtype))
132
+ y = self.dropout(y)
133
+ return y.to(hidden_states.dtype)
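The `forward_reference` above computes causal linear attention with a cumulative-sum trick. The following sketch (not part of the repo) checks that formulation against an explicit per-timestep loop on random tensors.

```python
# Sketch checking that the cumsum trick in forward_reference matches the explicit
# causal linear-attention loop: y_t = q_t^T (sum_{s<=t} k_s v_s^T) / (q_t^T sum_{s<=t} k_s + eps).
import torch

torch.manual_seed(0)
B, H, T, F, D, eps = 1, 2, 6, 4, 3, 1e-5
q = torch.rand(B, H, T, F)
k = torch.rand(B, H, T, F)
v = torch.randn(B, H, T, D)

# cumsum formulation, mirroring the broadcasted shapes in forward_reference
qe, ke, ve = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
y_fast = (qe * (ke * ve).cumsum(2)).sum(-1) / ((qe * ke.cumsum(2)).sum(-1) + eps)

# explicit loop over timesteps
y_loop = torch.zeros(B, H, T, D)
for t in range(T):
    kv = torch.einsum('bhsf,bhsd->bhfd', k[:, :, : t + 1], v[:, :, : t + 1])
    num = torch.einsum('bhf,bhfd->bhd', q[:, :, t], kv)
    den = torch.einsum('bhf,bhf->bh', q[:, :, t], k[:, :, : t + 1].sum(2)) + eps
    y_loop[:, :, t] = num / den.unsqueeze(-1)

assert torch.allclose(y_fast, y_loop, atol=1e-5)
```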
fla/layers/simple_gla.py ADDED
@@ -0,0 +1,261 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.activations import ACT2FN
15
+ from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class SimpleGatedLinearAttention(nn.Module):
22
+ r"""
23
+ The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
24
+ This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
25
+
26
+ Args:
27
+ mode (str, Optional):
28
+ Which GLA kernel to use.
29
+ Currently available: `chunk` and `fused_recurrent`.
30
+ Default: `chunk`.
31
+ hidden_size (int, Optional):
32
+ The hidden size of the input. Default: 1024.
33
+ expand_k (float, Optional):
34
+ The expansion ratio for the key dim. Default: 1.0.
35
+ expand_v (float, Optional):
36
+ The expansion ratio for the value dim. Default: 1.0.
37
+ num_heads (int, Optional):
38
+ The number of heads. Default: 4.
39
+ num_kv_heads (int, Optional):
40
+ The number of key/value heads, used for MQA. Default: None.
41
+ feature_map (str, Optional):
42
+ Feature map function applied to queries/keys. Default: None.
43
+ use_short_conv (bool, Optional):
44
+ Whether to use short convolutions. Default: `False`.
45
+ conv_size (int, Optional):
46
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
47
+ conv_bias (bool, Optional):
48
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
49
+ gate_fn (str, Optional):
50
+ The activation function for the output gate. Default: `swish`.
51
+ elementwise_affine (bool, Optional):
52
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
53
+ norm_eps (float, Optional):
54
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
55
+ gate_logit_normalizer (int, Optional):
56
+ The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
57
+ fuse_norm (bool, Optional):
58
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
59
+ layer_idx (int, Optional):
60
+ The index of the layer. Default: None.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ mode: str = 'chunk',
66
+ hidden_size: int = 1024,
67
+ expand_k: float = 1.,
68
+ expand_v: float = 1.,
69
+ num_heads: int = 4,
70
+ num_kv_heads: Optional[int] = None,
71
+ feature_map: Optional[str] = None,
72
+ use_short_conv: bool = True,
73
+ conv_size: int = 4,
74
+ conv_bias: bool = False,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ gate_logit_normalizer: int = 16,
79
+ fuse_norm: bool = True,
80
+ layer_idx: int = None,
81
+ ) -> SimpleGatedLinearAttention:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+
97
+ self.key_dim = int(hidden_size * expand_k)
98
+ self.value_dim = int(hidden_size * expand_v)
99
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
100
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
101
+ self.layer_idx = layer_idx
102
+
103
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
104
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
105
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
106
+
107
+ self.head_k_dim = self.key_dim // num_heads
108
+ self.head_v_dim = self.value_dim // num_heads
109
+
110
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
111
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
112
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
113
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
114
+
115
+ if use_short_conv:
116
+ self.conv_size = conv_size
117
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
118
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
119
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
120
+
121
+ self.gk_proj = nn.Linear(hidden_size, self.num_heads)
122
+
123
+ if gate_fn == 'swish' and fuse_norm:
124
+ self.g_norm_swish_gate = FusedRMSNormGated(
125
+ hidden_size=self.head_v_dim,
126
+ elementwise_affine=elementwise_affine,
127
+ eps=norm_eps
128
+ )
129
+ self.fuse_norm_and_gate = True
130
+ else:
131
+ self.fuse_norm_and_gate = False
132
+ self.g_norm = RMSNorm(
133
+ hidden_size=self.head_v_dim,
134
+ elementwise_affine=elementwise_affine,
135
+ eps=norm_eps
136
+ )
137
+ self.gate_fn = ACT2FN[gate_fn]
138
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
139
+
140
+ self.gate_logit_normalizer = gate_logit_normalizer
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ past_key_values: Optional[Cache] = None,
147
+ use_cache: Optional[bool] = False,
148
+ output_attentions: Optional[bool] = False,
149
+ **kwargs
150
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
151
+ if attention_mask is not None:
152
+ assert len(attention_mask.shape) == 2, (
153
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
154
+ "for padding purposes (0 indicating padding). "
155
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
156
+ )
157
+
158
+ # launching the triton kernel for just one token will actually be slower
159
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
160
+
161
+ last_state = None
162
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
163
+ last_state = past_key_values[self.layer_idx]
164
+
165
+ cu_seqlens = kwargs.get('cu_seqlens', None)
166
+ if self.use_short_conv:
167
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
168
+ if last_state is not None:
169
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
170
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
171
+ q, conv_state_q = self.q_conv1d(
172
+ x=self.q_proj(hidden_states),
173
+ mask=conv_mask,
174
+ cache=conv_state_q,
175
+ output_final_state=use_cache,
176
+ cu_seqlens=cu_seqlens
177
+ )
178
+ k, conv_state_k = self.k_conv1d(
179
+ x=self.k_proj(hidden_states),
180
+ mask=conv_mask,
181
+ cache=conv_state_k,
182
+ output_final_state=use_cache,
183
+ cu_seqlens=cu_seqlens
184
+ )
185
+ v, conv_state_v = self.v_conv1d(
186
+ x=self.v_proj(hidden_states),
187
+ mask=conv_mask,
188
+ cache=conv_state_v,
189
+ output_final_state=use_cache,
190
+ cu_seqlens=cu_seqlens
191
+ )
192
+ else:
193
+ q = self.q_proj(hidden_states)
194
+ k = self.k_proj(hidden_states)
195
+ v = self.v_proj(hidden_states)
196
+ gk = self.gk_proj(hidden_states)
197
+
198
+ if self.feature_map_fn is not None:
199
+ q, k = map(self.feature_map_fn, (q, k))
200
+ # dealing with left-padding
201
+ if attention_mask is not None:
202
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
203
+ q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
204
+ if self.num_kv_groups > 1:
205
+ k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
206
+ else:
207
+ k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
208
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
209
+
210
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
211
+ if mode == 'chunk':
212
+ o, recurrent_state = chunk_simple_gla(
213
+ q=q,
214
+ k=k,
215
+ v=v,
216
+ gk=gk,
217
+ initial_state=recurrent_state,
218
+ output_final_state=use_cache,
219
+ cu_seqlens=cu_seqlens,
220
+ head_first=False
221
+ )
222
+ elif mode == 'fused_recurrent':
223
+ o, recurrent_state = fused_recurrent_simple_gla(
224
+ q=q,
225
+ k=k,
226
+ v=v,
227
+ gk=gk,
228
+ initial_state=recurrent_state,
229
+ output_final_state=use_cache,
230
+ cu_seqlens=cu_seqlens,
231
+ head_first=False
232
+ )
233
+ else:
234
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
235
+
236
+ if past_key_values is not None:
237
+ past_key_values.update(
238
+ recurrent_state=recurrent_state,
239
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
240
+ layer_idx=self.layer_idx,
241
+ offset=q.shape[1]
242
+ )
243
+
244
+ g = self.g_proj(hidden_states)
245
+ if self.fuse_norm_and_gate:
246
+ g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
247
+ o = self.g_norm_swish_gate(o, g)
248
+ o = rearrange(o, 'b t h d -> b t (h d)')
249
+ else:
250
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
251
+ o = o * self.gate_fn(g)
252
+ o = self.o_proj(o)
253
+
254
+ return o, None, past_key_values
255
+
256
+ def state_size(self, **kwargs) -> int:
257
+ state_size = self.key_dim * self.head_v_dim
258
+ for module in self.children():
259
+ if isinstance(module, ShortConvolution):
260
+ state_size += module.state_size
261
+ return state_size
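A hedged usage sketch of the layer above; the import path and parameter choices are illustrative, and a CUDA device is assumed since the chunk/fused-recurrent kernels are Triton-based.

```python
# Hedged usage sketch (assumed import path; Triton kernels need a CUDA device).
import torch
from fla.layers.simple_gla import SimpleGatedLinearAttention

if torch.cuda.is_available():
    layer = SimpleGatedLinearAttention(
        mode='chunk', hidden_size=1024, num_heads=8, use_short_conv=True, layer_idx=0,
    ).to('cuda', torch.bfloat16)

    x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.bfloat16)
    o, _, _ = layer(x)          # no cache is returned unless use_cache=True
    print(o.shape)              # expected: torch.Size([2, 256, 1024])
```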
fla/modules/__pycache__/rotary.cpython-312.pyc ADDED
Binary file (23.2 kB). View file
 
fla/modules/parallel.py ADDED
@@ -0,0 +1,37 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch.nn as nn
7
+ from torch.distributed import DeviceMesh
8
+ from torch.distributed.tensor import DTensor, distribute_module
9
+ from torch.distributed.tensor.parallel import ParallelStyle
10
+ from torch.distributed.tensor.placement_types import Placement
11
+
12
+
13
+ class PrepareModuleWeight(ParallelStyle):
14
+ def __init__(self, *, layouts: Optional[Placement] = None):
15
+ super().__init__()
16
+ self.layouts = layouts
17
+
18
+ def _replicate_module_fn(
19
+ self,
20
+ name: str,
21
+ module: nn.Module,
22
+ device_mesh: DeviceMesh
23
+ ):
24
+ for p_name, param in module.named_parameters():
25
+ replicated_param = nn.Parameter(
26
+ DTensor.from_local(param, device_mesh, [self.layouts], run_check=False)
27
+ )
28
+ module.register_parameter(p_name, replicated_param)
29
+
30
+ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
31
+ return distribute_module(
32
+ module,
33
+ device_mesh,
34
+ partition_fn=self._replicate_module_fn,
35
+ input_fn=None,
36
+ output_fn=None
37
+ )
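`PrepareModuleWeight` only re-registers a module's parameters as DTensors with the given layout, leaving inputs and outputs untouched; `flame/models/parallelize_fla.py` below uses it to replicate `lm_head` when loss parallelism is disabled. A hedged sketch of standalone usage (toy module; mesh created under `torchrun`):

```python
# Hedged sketch (toy module and sizes); launch with `torchrun --nproc_per_node=N`
# so a process group and LOCAL_RANK/WORLD_SIZE environment variables exist.
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Replicate
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from fla.modules.parallel import PrepareModuleWeight

torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
tp_mesh = init_device_mesh("cuda", (int(os.environ.get("WORLD_SIZE", 1)),), mesh_dim_names=("tp",))

model = nn.ModuleDict({"lm_head": nn.Linear(1024, 32000, bias=False)}).cuda()
parallelize_module(model, tp_mesh, {"lm_head": PrepareModuleWeight(layouts=Replicate())})
print(model["lm_head"].weight)  # now a DTensor replicated across the tp mesh
```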
flame/__pycache__/config_manager.cpython-312.pyc ADDED
Binary file (36.9 kB). View file
 
flame/components/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (3.2 kB). View file
 
flame/models/__init__.py ADDED
File without changes
flame/models/fla.toml ADDED
@@ -0,0 +1,67 @@
1
+ [model]
2
+ config = "fla-hub/transformer-1.3B-100B"
3
+ tokenizer_path = "fla-hub/transformer-1.3B-100B"
4
+
5
+ [job]
6
+ dump_folder = "exp"
7
+ print_args = true
8
+
9
+ [training]
10
+ batch_size = 32
11
+ seq_len = 2048
12
+ context_len = 2048
13
+ gradient_accumulation_steps = 1
14
+ steps = 20480
15
+ max_norm = 1.0
16
+ skip_nan_inf = true
17
+ data_parallel_replicate_degree = 1
18
+ data_parallel_shard_degree = -1
19
+ tensor_parallel_degree = 1
20
+ compile = false
21
+ dataset = "HuggingFaceFW/fineweb-edu"
22
+ dataset_name = "default"
23
+ num_workers = 32
24
+ pin_memory = false
25
+ persistent_workers = false
26
+ prefetch_factor = 2
27
+ seed = 42
28
+ varlen = false
29
+
30
+ [optimizer]
31
+ name = "AdamW"
32
+ eps = 1e-15
33
+ lr = 3e-4
34
+
35
+ [lr_scheduler]
36
+ warmup_steps = 1024
37
+ decay_type = "cosine"
38
+ lr_min = 0.1
39
+
40
+ [checkpoint]
41
+ enable_checkpoint = true
42
+ folder = "checkpoint"
43
+ interval_type = "steps"
44
+ interval = 2048
45
+ model_weights_only = false
46
+ export_dtype = "float32"
47
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
48
+
49
+ [profiling]
50
+ enable_profiling = true
51
+ save_traces_folder = "profile_trace"
52
+ profile_freq = 512
53
+
54
+ [metrics]
55
+ log_freq = 32
56
+ enable_wandb = true
57
+
58
+ [experimental]
59
+ context_parallel_degree = 1
60
+ pipeline_parallel_degree = 1
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+
66
+ [activation_checkpoint]
67
+ mode = "none"
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to FLA models.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
+ def parallelize_fla(
37
+ model: nn.Module,
38
+ world_mesh: DeviceMesh,
39
+ parallel_dims: ParallelDims,
40
+ job_config: JobConfig,
41
+ ):
42
+ """
43
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
44
+ parallelism to the model.
45
+
46
+ NOTE: The passed-in model preferably should be on meta device. Otherwise,
47
+ the model must fit on GPU or CPU memory.
48
+ """
49
+
50
+ if parallel_dims.tp_enabled:
51
+ if (
52
+ job_config.experimental.enable_async_tensor_parallel
53
+ and not job_config.training.compile
54
+ ):
55
+ raise RuntimeError("Async TP requires --training.compile")
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8=enable_float8_linear,
62
+ enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
63
+ )
64
+
65
+ if job_config.activation_checkpoint.mode != "none":
66
+ apply_ac(model, job_config.activation_checkpoint)
67
+
68
+ # turn on per-block compile after AC wrapping and before FSDP
69
+ if job_config.training.compile:
70
+ apply_compile(model)
71
+
72
+ if (
73
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
74
+ ): # apply FSDP or HSDP, potentially with Context Parallel
75
+ if parallel_dims.dp_replicate_enabled:
76
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
77
+ else:
78
+ dp_mesh_dim_names = ("dp_shard_cp",)
79
+
80
+ apply_fsdp(
81
+ model,
82
+ world_mesh[tuple(dp_mesh_dim_names)],
83
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
84
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
85
+ pp_enabled=parallel_dims.pp_enabled,
86
+ cpu_offload=job_config.training.enable_cpu_offload,
87
+ reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
88
+ )
89
+
90
+ if parallel_dims.dp_replicate_enabled:
91
+ logger.info("Applied HSDP to the model")
92
+ else:
93
+ logger.info("Applied FSDP to the model")
94
+
95
+ if parallel_dims.cp_enabled:
96
+ logger.info("Applied Context Parallel to the model")
97
+
98
+ if job_config.training.enable_cpu_offload:
99
+ logger.info("Applied CPU Offloading to the model")
100
+ elif parallel_dims.dp_replicate_enabled:
101
+ if world_mesh.ndim > 1:
102
+ raise RuntimeError("DDP has not supported > 1D parallelism")
103
+ apply_ddp(
104
+ model,
105
+ world_mesh,
106
+ enable_compile=job_config.training.compile,
107
+ enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
108
+ )
109
+
110
+
111
+ class TPPlan:
112
+ def __init__(
113
+ self,
114
+ model=None,
115
+ loss_parallel=False,
116
+ enable_float8=False,
117
+ ):
118
+ self.model = model
119
+ self.loss_parallel = loss_parallel
120
+ self.enable_float8 = enable_float8
121
+ self.base_model_prefix = getattr(model, "base_model_prefix", "model")
122
+
123
+ # TODO(vkuzo): once float8 configuration supports delayed scaling,
124
+ # add a check here to enforce supported float8 all-gather configurations
125
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
126
+ try:
127
+ from torchao.float8.float8_tensor_parallel import (
128
+ Float8ColwiseParallel,
129
+ Float8RowwiseParallel,
130
+ PrepareFloat8ModuleInput
131
+ )
132
+ except ImportError:
133
+ Float8ColwiseParallel = None
134
+ Float8RowwiseParallel = None
135
+ PrepareFloat8ModuleInput = None
136
+ if self.enable_float8 and Float8ColwiseParallel is not None:
137
+ self.rowwise_parallel = Float8RowwiseParallel
138
+ self.colwise_parallel = Float8ColwiseParallel
139
+ self.prepare_module_input = PrepareFloat8ModuleInput
140
+ self.prepare_module_output = PrepareModuleOutput
141
+ else:
142
+ self.rowwise_parallel = RowwiseParallel
143
+ self.colwise_parallel = ColwiseParallel
144
+ self.prepare_module_input = PrepareModuleInput
145
+ self.prepare_module_output = PrepareModuleOutput
146
+
147
+ @property
148
+ def model_plan(self):
149
+ plans = {
150
+ f"{self.base_model_prefix}.embeddings": RowwiseParallel(
151
+ input_layouts=Replicate(),
152
+ output_layouts=Shard(1),
153
+ ),
154
+ f"{self.base_model_prefix}.norm": SequenceParallel(),
155
+ }
156
+ if self.loss_parallel:
157
+ plans.update(
158
+ {
159
+ "lm_head": ColwiseParallel(
160
+ input_layouts=Shard(1),
161
+ output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
162
+ use_local_output=not self.loss_parallel,
163
+ ),
164
+ }
165
+ )
166
+ else:
167
+ plans.update(
168
+ {
169
+ "lm_head": PrepareModuleWeight(layouts=Replicate()),
170
+ "criterion": LinearLossParallel(),
171
+ }
172
+ )
173
+ return plans
174
+
175
+ @property
176
+ def layer_plan(self):
177
+ return {
178
+ "attn_norm": SequenceParallel(),
179
+ **self.attn_plan,
180
+ "mlp_norm": SequenceParallel(),
181
+ **self.mlp_plan,
182
+ }
183
+
184
+ @property
185
+ def attn_plan(self):
186
+ raise NotImplementedError(
187
+ f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
188
+ )
189
+
190
+ @property
191
+ def mlp_plan(self):
192
+ return {
193
+ "mlp": self.prepare_module_input(
194
+ input_layouts=(Shard(1),),
195
+ desired_input_layouts=(Replicate(),),
196
+ ),
197
+ "mlp.gate_proj": self.colwise_parallel(),
198
+ "mlp.up_proj": self.colwise_parallel(),
199
+ "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
200
+ "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
201
+ }
202
+
203
+
204
+ class TransformerTPPlan(TPPlan):
205
+
206
+ @property
207
+ def attn_plan(self):
208
+ return {
209
+ "attn": self.prepare_module_input(
210
+ input_kwarg_layouts={"hidden_states": Shard(1)},
211
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
212
+ ),
213
+ "attn.q_proj": self.colwise_parallel(),
214
+ "attn.k_proj": self.colwise_parallel(),
215
+ "attn.v_proj": self.colwise_parallel(),
216
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
217
+ }
218
+
219
+
220
+ class GLATPPlan(TPPlan):
221
+
222
+ @property
223
+ def attn_plan(self):
224
+ return {
225
+ "attn": self.prepare_module_input(
226
+ input_kwarg_layouts={"hidden_states": Shard(1)},
227
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
228
+ ),
229
+ "attn.q_proj": self.colwise_parallel(),
230
+ "attn.k_proj": self.colwise_parallel(),
231
+ "attn.v_proj": self.colwise_parallel(),
232
+ "attn.g_proj": self.colwise_parallel(),
233
+ "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
234
+ "attn.gk_proj.1": self.colwise_parallel(),
235
+ "attn.g_norm": SequenceParallel(sequence_dim=-1),
236
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
237
+ }
238
+
239
+
240
+ TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
+ def apply_tp(
244
+ model: nn.Module,
245
+ tp_mesh: DeviceMesh,
246
+ loss_parallel: bool,
247
+ enable_float8: bool,
248
+ enable_async_tp: bool,
249
+ ):
250
+ """Apply tensor parallelism."""
251
+ # 1. Parallelize the embedding and shard its outputs (which are the first
252
+ # transformer block's inputs)
253
+ # 2. Parallelize the root norm layer over the sequence dim
254
+ # 3. Parallelize the final linear output layer
255
+ tp_plan = TP_PLAN_MAP[model.config.model_type](
256
+ model, loss_parallel=loss_parallel, enable_float8=enable_float8
257
+ )
258
+ parallelize_module(model, tp_mesh, tp_plan.model_plan)
259
+
260
+ blocks = get_blocks(model)
261
+ if blocks is None:
262
+ logger.warning("No block found for tensor parallelism")
263
+ else:
264
+ for _, block in enumerate(blocks):
265
+ parallelize_module(
266
+ module=block,
267
+ device_mesh=tp_mesh,
268
+ parallelize_plan=tp_plan.layer_plan,
269
+ )
270
+
271
+ if enable_async_tp:
272
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
273
+
274
+ torch._inductor.config._micro_pipeline_tp = True
275
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
276
+
277
+ logger.info(
278
+ f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
279
+ "Tensor Parallelism to the model"
280
+ )
281
+
282
+
283
+ # for selective op activation checkpointing
284
+ _save_list = {
285
+ torch.ops.aten.mm.default,
286
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
287
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
288
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
289
+ # for low precision training, it's useful to always save
290
+ # the result of max, since the absolute maximum is
291
+ # used to compute the scaling factor for quantization.
292
+ torch.ops.aten.max.default,
293
+ }
294
+
295
+
296
+ def _apply_ac_to_block(module: nn.Module, ac_config):
297
+ valid_ac_modes = ("full", "selective")
298
+ if ac_config.mode not in valid_ac_modes:
299
+ raise ValueError(
300
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
301
+ )
302
+
303
+ if ac_config.mode == "full":
304
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
305
+
306
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
307
+ use_op_sac = ac_config.selective_ac_option == "op"
308
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
309
+ if not use_op_sac and not use_layer_sac:
310
+ raise ValueError(
311
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
312
+ f"Valid options: 'op' or a positive int representing layer frequency"
313
+ )
314
+ if use_op_sac:
315
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
316
+
317
+ def _get_custom_policy(meta):
318
+ def _custom_policy(ctx, func, *args, **kwargs):
319
+ mode = "recompute" if ctx.is_recompute else "forward"
320
+ mm_count_key = f"{mode}_mm_count"
321
+ if func == torch.ops.aten.mm.default:
322
+ meta[mm_count_key] += 1
323
+ # Saves output of all compute ops, except every second mm
324
+ to_save = func in _save_list and not (
325
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
326
+ )
327
+ return (
328
+ CheckpointPolicy.MUST_SAVE
329
+ if to_save
330
+ else CheckpointPolicy.PREFER_RECOMPUTE
331
+ )
332
+
333
+ return _custom_policy
334
+
335
+ def selective_checkpointing_context_fn():
336
+ meta = defaultdict(int)
337
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
338
+
339
+ return ptd_checkpoint_wrapper(
340
+ module,
341
+ context_fn=selective_checkpointing_context_fn,
342
+ preserve_rng_state=False,
343
+ )
344
+ elif use_layer_sac:
345
+ # Checkpoint every `ac_freq` of the modules passed to this function
346
+ ac_freq = int(ac_config.selective_ac_option)
347
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
348
+ ptd_checkpoint_wrapper._count += 1
349
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
350
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
351
+ else:
352
+ return module
353
+
354
+
355
+ def apply_ac(model: nn.Module, ac_config):
356
+ """Apply activation checkpointing to the model."""
357
+ blocks = get_blocks(model)
358
+ if blocks is None:
359
+ logger.warning("No block found for activation checkpointing")
360
+ return
361
+
362
+ for layer_id, block in blocks.named_children():
363
+ block = _apply_ac_to_block(block, ac_config)
364
+ blocks.register_module(layer_id, block)
365
+
366
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
+ def apply_compile(model: nn.Module):
370
+ """
371
+ Apply torch.compile to each block, which makes compilation efficient due to
372
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
373
+ """
374
+
375
+ blocks = get_blocks(model)
376
+ if blocks is None:
377
+ logger.warning("No block found for torch.compile")
378
+ else:
379
+ for layer_id, block in blocks.named_children():
380
+ block = torch.compile(block)
381
+ blocks.register_module(layer_id, block)
382
+ logger.info("Compiling each block with torch.compile")
383
+
384
+ real_model = get_model(model)
385
+
386
+ logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
387
+ embeddings_key = get_components_name(real_model, "tok_embeddings")
388
+ if embeddings_key is not None:
389
+ embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
390
+ real_model.register_module(embeddings_key, embeddings)
391
+
392
+ norm_key = get_components_name(real_model, "norm")
393
+ if norm_key is not None:
394
+ norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
395
+ real_model.register_module(norm_key, norm)
396
+
397
+ lm_head_key = get_components_name(model, "lm_head")
398
+ if lm_head_key is not None:
399
+ lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
400
+ model.register_module(lm_head_key, lm_head)
401
+
402
+ logger.info("Compiling the entire model with torch.compile")
403
+ model = torch.compile(model)
404
+
405
+
406
+ def apply_fsdp(
407
+ model: nn.Module,
408
+ dp_mesh: DeviceMesh,
409
+ param_dtype: torch.dtype,
410
+ reduce_dtype: torch.dtype,
411
+ pp_enabled: bool,
412
+ cpu_offload: bool = False,
413
+ reshard_after_forward_policy: str = "default",
414
+ ):
415
+ """
416
+ Apply data parallelism (via FSDP2) to the model.
417
+
418
+ Args:
419
+ model (nn.Module): The model to apply data parallelism to.
420
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
421
+ param_dtype (torch.dtype): The data type to use for model parameters.
422
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
423
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
424
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
425
+ reshard_after_forward_policy (str, optional):
426
+ The policy to use for resharding after forward pass. Defaults to "default".
427
+ Other options: "never", "always".
428
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
429
+ - "always" will enable `reshard_after_forward` for all forward passes.
430
+ - "never" will disable `reshard_after_forward` for all forward passes.
431
+
432
+ """
433
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
434
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
435
+ if cpu_offload:
436
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
437
+
438
+ blocks = get_blocks(model)
439
+ if blocks is None:
440
+ logger.warning("No block found for FSDP")
441
+ else:
442
+ total_blocks = len(blocks)
443
+ for layer_id, block in enumerate(blocks):
444
+ if reshard_after_forward_policy == "always":
445
+ reshard_after_forward = True
446
+ elif reshard_after_forward_policy == "never":
447
+ reshard_after_forward = False
448
+ elif reshard_after_forward_policy == "default":
449
+ if pp_enabled:
450
+ # For PP, do not reshard after forward to avoid per-microbatch
451
+ # all-gathers, which can be expensive and non-overlapped
452
+ reshard_after_forward = False
453
+ else:
454
+ # As an optimization, do not reshard after forward for the last
455
+ # transformer block since FSDP would prefetch it immediately
456
+ reshard_after_forward = int(layer_id) < total_blocks - 1
457
+ else:
458
+ raise ValueError(
459
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
460
+ )
461
+ fully_shard(
462
+ block,
463
+ **fsdp_config,
464
+ reshard_after_forward=reshard_after_forward,
465
+ )
466
+
467
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
+ def apply_ddp(
471
+ model: nn.Module,
472
+ dp_mesh: DeviceMesh,
473
+ enable_compile: bool,
474
+ enable_compiled_autograd: bool,
475
+ ):
476
+ if enable_compile:
477
+ if enable_compiled_autograd:
478
+ torch._dynamo.config.optimize_ddp = (
479
+ "python_reducer_without_compiled_forward"
480
+ )
481
+ else:
482
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
483
+
484
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
485
+
486
+ logger.info("Applied DDP to the model")
487
+
488
+
489
+ def get_model(model):
490
+ base_model_prefix = getattr(model, "base_model_prefix", "model")
491
+ if not hasattr(model, base_model_prefix):
492
+ return None
493
+ model = getattr(model, base_model_prefix)
494
+ return model
495
+
496
+
497
+ def get_blocks(model):
498
+ # TODO[flame]: adapt for networks that do not use a 'layers' attribute
499
+ model = get_model(model)
500
+ if not hasattr(model, "layers"):
501
+ logger.warning('no "layers" attribute found in the model')
502
+ return None
503
+ return model.layers
504
+
505
+
506
+ def get_components_name(model, component_name):
507
+ """
508
+ We try to find the tok_embeddings, norm and lm_head modules.
509
+ We do not search inside the blocks; for those, see `get_blocks`.
510
+ We assume the model has the following structure:
511
+ LlamaForCausalLM:
512
+ Model:
513
+ embed_tokens,
514
+ layers,
515
+ norm,
516
+ lm_head
517
+ ***
518
+ so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
519
+ and for 'lm_head' we need to pass `model`
520
+ ***
521
+ """
522
+
523
+ if component_name == "tok_embeddings":
524
+ if hasattr(model, "tok_embeddings"):
525
+ return "tok_embeddings"
526
+ elif hasattr(model, "embed_tokens"):
527
+ return "embed_tokens"
528
+ elif hasattr(model, "embeddings"):
529
+ return "embeddings"
530
+ else:
531
+ logger.warning("No tok_embeddings found in model")
532
+ return None
533
+
534
+ elif component_name == "norm":
535
+ if hasattr(model, "norm"):
536
+ return "norm"
537
+ elif hasattr(model, "norms"):
538
+ return "norms"
539
+ elif hasattr(model, "layernorm"):
540
+ return "layernorm"
541
+ else:
542
+ logger.warning("No norm found in model")
543
+ return None
544
+
545
+ elif component_name == "lm_head":
546
+ if hasattr(model, "lm_head"):
547
+ return "lm_head"
548
+ else:
549
+ logger.warning("No lm_head found in model")
550
+ return None
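The helpers at the end resolve model internals by attribute probing rather than by class. A hedged sketch on a toy HF-style causal LM (assumes `flame` and its torchtitan dependency are importable):

```python
# Hedged sketch (toy model) of how get_model/get_blocks/get_components_name resolve
# a HF-style causal LM; importing the module pulls in torchtitan and fla.
import torch.nn as nn

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model


class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(100, 16)
        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(2)])
        self.norm = nn.LayerNorm(16)


class ToyForCausalLM(nn.Module):
    base_model_prefix = "model"

    def __init__(self):
        super().__init__()
        self.model = ToyBackbone()
        self.lm_head = nn.Linear(16, 100, bias=False)


m = ToyForCausalLM()
print(len(get_blocks(m)))                                   # 2
print(get_components_name(get_model(m), "tok_embeddings"))  # 'embeddings'
print(get_components_name(get_model(m), "norm"))            # 'norm'
print(get_components_name(m, "lm_head"))                    # 'lm_head'
```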
flame/tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (127 Bytes). View file
 
flame/tools/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.14 kB). View file
 
flame/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (127 Bytes). View file
 
flame/utils/__pycache__/checkpoint.cpython-312.pyc ADDED
Binary file (4.06 kB). View file
 
flame/utils/__pycache__/convert_dcp_to_hf.cpython-312.pyc ADDED
Binary file (3.72 kB). View file
 
flame/utils/__pycache__/hf_utils.cpython-312.pyc ADDED
Binary file (4.45 kB). View file
 
flame/utils/convert_hf_to_dcp.py ADDED
@@ -0,0 +1,34 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.distributed.checkpoint as DCP
9
+ from transformers import AutoModelForCausalLM
10
+
11
+ import fla # noqa
12
+ from torchtitan.tools.logging import init_logger, logger
13
+
14
+
15
+ @torch.inference_mode()
16
+ def convert_hf_weights(model: str, checkpoint: Path):
17
+ logger.info(f"Loading model from {model}")
18
+ model = AutoModelForCausalLM.from_pretrained(model)
19
+ state_dict = model.state_dict()
20
+
21
+ logger.info(f"Writing to DCP at '{checkpoint}'")
22
+ checkpoint.mkdir(parents=True, exist_ok=True)
23
+ storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
24
+ DCP.save({"model": state_dict}, storage_writer=storage_writer)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ init_logger()
29
+ parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
30
+ parser.add_argument("--model", type=str, required=True)
31
+ parser.add_argument("--checkpoint", type=Path, required=True)
32
+ args = parser.parse_args()
33
+
34
+ convert_hf_weights(args.model, args.checkpoint)
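A hedged example of driving the converter programmatically; the model id and output directory are illustrative, equivalent to passing `--model` and `--checkpoint` on the command line.

```python
# Hedged sketch: call the converter from Python instead of the CLI.
# The model id and output path below are illustrative only.
from pathlib import Path

from flame.utils.convert_hf_to_dcp import convert_hf_weights
from torchtitan.tools.logging import init_logger

init_logger()
convert_hf_weights("fla-hub/transformer-1.3B-100B", Path("exp/dcp_checkpoint"))
```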
logs/none_99omtdbz/attempt_0/0/stdout.log ADDED
File without changes
logs/none_99omtdbz/attempt_0/3/stdout.log ADDED
File without changes
logs/none_99omtdbz/attempt_0/4/stdout.log ADDED
File without changes
logs/none_99omtdbz/attempt_0/7/stdout.log ADDED
File without changes
profile_trace/iteration_10752/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1536/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1536/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1536/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1536/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_2048/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_2048/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_25088/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_25088/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_25088/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_25088/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_31232/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_31232/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_34304/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_34304/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_34304/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_34304/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_5120/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_5120/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9728/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9728/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_9728/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff