YongganFu commited on
Commit
8609b94
·
verified ·
1 Parent(s): 5f71aa0

Upload NemotronFlashForCausalLM

Browse files
config.json CHANGED
@@ -1,14 +1,14 @@
1
  {
2
  "architectures": [
3
- "FastSLMForCausalLM"
4
  ],
5
  "attention_dropout": 0.0,
6
  "attn_hidden_size": -1,
7
- "attn_implementation": "flash_attention_2",
8
- "attn_implementation_new": "flash_attention_2",
9
  "auto_map": {
10
- "AutoConfig": "configuration_fast_slm.FastSLMConfig",
11
- "AutoModelForCausalLM": "modeling_fast_slm.FastSLMForCausalLM"
12
  },
13
  "bos_token_id": 1,
14
  "calc_logits_for_entire_prompt": false,
 
1
  {
2
  "architectures": [
3
+ "NemotronFlashForCausalLM"
4
  ],
5
  "attention_dropout": 0.0,
6
  "attn_hidden_size": -1,
7
+ "attn_implementation": "fused_mha",
8
+ "attn_implementation_new": "fused_mha",
9
  "auto_map": {
10
+ "AutoConfig": "configuration_nemotron_flash.NemotronFlashConfig",
11
+ "AutoModelForCausalLM": "modeling_nemotron_flash.NemotronFlashForCausalLM"
12
  },
13
  "bos_token_id": 1,
14
  "calc_logits_for_entire_prompt": false,
configuration_nemotron_flash.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Nemotron Flash model configuration"""
16
+ import math
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class NemotronFlashConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
28
+ Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the jamba-small architecture.
30
+
31
+ [ai21labs/jamba-small](https://huggingface.co/ai21labs/Jamba-v0.1)
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 65536):
39
+ Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`JambaModel`]
41
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
42
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
43
+ model has a output word embedding layer.
44
+ hidden_size (`int`, *optional*, defaults to 4096):
45
+ Dimension of the hidden representations.
46
+ intermediate_size (`int`, *optional*, defaults to 14336):
47
+ Dimension of the MLP representations.
48
+ num_hidden_layers (`int`, *optional*, defaults to 32):
49
+ Number of hidden layers in the Transformer encoder.
50
+ num_attention_heads (`int`, *optional*, defaults to 32):
51
+ Number of attention heads for each attention layer in the Transformer encoder.
52
+ num_key_value_heads (`int`, *optional*, defaults to 8):
53
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
54
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
55
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
56
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
57
+ by meanpooling all the original heads within that group. For more details checkout [this
58
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ calc_logits_for_entire_prompt (`bool`, *optional*, defaults to `False`):
69
+ Whether or not to calculate logits for entire prompt during generation. If `False`, only the logits of the
70
+ last prompt token will be calculated, which are the only logits needed for generation. For long sequences,
71
+ the logits for the entire sequence may use a lot of memory so setting `calc_logits_for_entire_prompt=False`
72
+ will reduce memory footprint significantly.
73
+ Note: some generation features may not be available if this is set to `False`.
74
+ output_router_logits (`bool`, *optional*, defaults to `False`):
75
+ Whether or not the router logits should be returned by the model. Enabling this will also
76
+ allow the model to output the auxiliary loss. See [here]() for more details
77
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
78
+ The aux loss factor for the total loss.
79
+ pad_token_id (`int`, *optional*, defaults to 0):
80
+ The id of the padding token.
81
+ bos_token_id (`int`, *optional*, defaults to 1):
82
+ The id of the "beginning-of-sequence" token.
83
+ eos_token_id (`int`, *optional*, defaults to 2):
84
+ The id of the "end-of-sequence" token.
85
+ sliding_window (`int`, *optional*):
86
+ Sliding window attention window size. If not specified, will default to `None`.
87
+ n_ctx (`int`, *optional*, defaults to 262144):
88
+ This value doesn't have any real effect. The maximum sequence length that this model is intended to be
89
+ used with. It can be used with longer sequences, but performance may degrade.
90
+ attention_dropout (`float`, *optional*, defaults to 0.0):
91
+ The dropout ratio for the attention probabilities.
92
+ num_experts_per_tok (`int`, *optional*, defaults to 2):
93
+ The number of experts to root per-token, can be also interpreted as the `top-p` routing
94
+ parameter
95
+ num_experts (`int`, *optional*, defaults to 16):
96
+ Number of experts per Sparse MLP layer.
97
+ use_mamba_kernels (`bool`, *optional*, defaults to `True`):
98
+ Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
99
+ `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
100
+ `True` and kernels are not available
101
+ mamba_d_state (`int`, *optional*, defaults to 16):
102
+ The dimension the mamba state space latents
103
+ mamba_d_conv (`int`, *optional*, defaults to 4):
104
+ The size of the mamba convolution kernel
105
+ mamba_expand (`int`, *optional*, defaults to 2):
106
+ Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
107
+ mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
108
+ Rank of the the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
109
+ mamba_conv_bias (`bool`, *optional*, defaults to `True`):
110
+ Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
111
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`):
112
+ Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
113
+ mamba_inner_layernorms (`bool`, *optional*, defaults to `True`):
114
+ Flag indicating whether or not to apply layernorms to internal mamba activations
115
+
116
+ """
117
+
118
+ model_type = "jamba"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=65536,
124
+ tie_word_embeddings=False,
125
+ hidden_size=4096,
126
+ intermediate_size=14336,
127
+ num_hidden_layers=32,
128
+ num_attention_heads=32,
129
+ num_key_value_heads=8,
130
+ hidden_act="silu",
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-6,
133
+ use_cache=True,
134
+ calc_logits_for_entire_prompt=False,
135
+ output_router_logits=False,
136
+ router_aux_loss_coef=0.001,
137
+ pad_token_id=0,
138
+ bos_token_id=1,
139
+ eos_token_id=2,
140
+ sliding_window=None,
141
+ max_position_embeddings=262144,
142
+ orig_max_position_embeddings=None,
143
+ attention_dropout=0.0,
144
+ num_experts_per_tok=2,
145
+ num_experts=16,
146
+ use_mamba_kernels=True,
147
+ mamba_d_state=16,
148
+ mamba_d_conv=4,
149
+ mamba_expand=2,
150
+ mamba_dt_rank="auto",
151
+ mamba_conv_bias=True,
152
+ mamba_proj_bias=False,
153
+ mamba_inner_layernorms=True,
154
+
155
+ hybrid_decoder_layer='mamba',
156
+
157
+ global_attn_idx=None,
158
+
159
+ attn_implementation_new='flash_attention_2',
160
+
161
+ mamba2_headdim=64,
162
+
163
+ rope_type=None,
164
+
165
+ layer_types=None,
166
+
167
+ ffn_expand_ratio=None,
168
+
169
+ d_conv=4,
170
+
171
+ **kwargs,
172
+ ):
173
+ self.vocab_size = vocab_size
174
+ self.tie_word_embeddings = tie_word_embeddings
175
+ self.hidden_size = hidden_size
176
+ self.intermediate_size = intermediate_size
177
+ self.num_hidden_layers = num_hidden_layers
178
+ self.num_attention_heads = num_attention_heads
179
+ self.sliding_window = sliding_window
180
+ self.max_position_embeddings = max_position_embeddings
181
+ self.orig_max_position_embeddings = orig_max_position_embeddings
182
+ self.attention_dropout = attention_dropout
183
+
184
+ # for backward compatibility
185
+ if num_key_value_heads is None:
186
+ num_key_value_heads = num_attention_heads
187
+
188
+ self.num_key_value_heads = num_key_value_heads
189
+ self.hidden_act = hidden_act
190
+ self.initializer_range = initializer_range
191
+ self.rms_norm_eps = rms_norm_eps
192
+
193
+ self.use_cache = use_cache
194
+ self.calc_logits_for_entire_prompt = calc_logits_for_entire_prompt
195
+ self.output_router_logits = output_router_logits
196
+ self.router_aux_loss_coef = router_aux_loss_coef
197
+
198
+ self.num_experts_per_tok = num_experts_per_tok
199
+ self.num_experts = num_experts
200
+
201
+ self.use_mamba_kernels = use_mamba_kernels
202
+ self.mamba_d_state = mamba_d_state
203
+ self.mamba_d_conv = mamba_d_conv
204
+ self.mamba_expand = mamba_expand
205
+ self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
206
+ self.mamba_conv_bias = mamba_conv_bias
207
+ self.mamba_proj_bias = mamba_proj_bias
208
+ self.mamba_inner_layernorms = mamba_inner_layernorms
209
+
210
+ # added by Xin
211
+ self.kq_norm = kwargs.pop("kq_norm", None)
212
+ self.rope = kwargs.pop("rope", False)
213
+ self.rope_theta = kwargs.pop("rope_theta", 10000.0)
214
+ self.num_memory_tokens = kwargs.pop("num_memory_tokens", 0)
215
+ self.attn_hidden_size = kwargs.pop("attn_hidden_size", -1)
216
+ self.kq_head_dim = kwargs.pop("kq_head_dim", -1)
217
+ self.v_head_dim = kwargs.pop("v_head_dim", -1)
218
+
219
+ #! adhoc change
220
+ self.new_seq_length = 2048
221
+
222
+ self.hybrid_decoder_layer = hybrid_decoder_layer
223
+
224
+ self.global_attn_idx = global_attn_idx
225
+
226
+ self.attn_implementation_new = attn_implementation_new
227
+
228
+ self.mamba2_headdim = mamba2_headdim
229
+
230
+ self.rope_type = rope_type
231
+
232
+ self.layer_types = layer_types
233
+
234
+ self.ffn_expand_ratio = ffn_expand_ratio
235
+
236
+ self.d_conv = d_conv
237
+
238
+ self.mlp_hidden_act = kwargs.pop("mlp_hidden_act", "silu")
239
+
240
+ super().__init__(
241
+ pad_token_id=pad_token_id,
242
+ bos_token_id=bos_token_id,
243
+ eos_token_id=eos_token_id,
244
+ tie_word_embeddings=tie_word_embeddings,
245
+ **kwargs,
246
+ )
generation_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.56.2",
7
+ "use_cache": false
8
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fdb6a7f8e726e8e00f5f63dde79d18b499722f6ce7199027679cd01c9f71a87
3
- size 1930804728
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b15235dc83bb411387d9a15694e02d3697051e303b067336e17c302a17b6125d
3
+ size 1930804368
modeling_nemotron_flash.py ADDED
@@ -0,0 +1,2023 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 NVIDIA Corporation. All rights reserved.
3
+
4
+ """ PyTorch Nemotron-Flash model."""
5
+ import inspect
6
+ import math
7
+ import copy
8
+ import warnings
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+ import time
11
+ import numpy as np
12
+ import os
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.utils.checkpoint
17
+ from torch import nn
18
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
19
+
20
+ torch._inductor.config.max_autotune_gemm_backends = ["aten"]
21
+
22
+ from transformers.activations import ACT2FN
23
+ from transformers.cache_utils import Cache, DynamicCache
24
+ from transformers.modeling_outputs import (
25
+ MoeCausalLMOutputWithPast,
26
+ MoeModelOutputWithPast,
27
+ )
28
+ from transformers.modeling_utils import PreTrainedModel
29
+ from transformers.generation import GenerationMixin
30
+
31
+ try:
32
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
33
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
34
+ except ImportError:
35
+ pass
36
+
37
+ from transformers.utils import (
38
+ is_flash_attn_greater_or_equal_2_10,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+ from .configuration_nemotron_flash import NemotronFlashConfig
43
+
44
+ import math
45
+
46
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
47
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
48
+
49
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
50
+
51
+ from einops import rearrange, repeat, reduce, pack, unpack
52
+
53
+ from .fused_mha_with_cache import fused_mha_interface
54
+
55
+ from .mamba2 import Mamba2
56
+ from mamba_ssm.utils.generation import InferenceParams
57
+ from .delta_net import Cache as fla_cache
58
+ from .delta_net import DeltaNet
59
+ import torch._dynamo
60
+ torch._dynamo.config.suppress_errors = True
61
+
62
+ from torch.cuda import CUDAGraph
63
+
64
+ logger = logging.get_logger(__name__)
65
+
66
+ _CONFIG_FOR_DOC = "NemotronFlashConfig"
67
+
68
+
69
+ class NemotronFlashRMSNorm(nn.Module):
70
+
71
+ def __init__(self, hidden_size, learnable_weight=True, eps=1e-6):
72
+ super().__init__()
73
+ if learnable_weight:
74
+ self.weight = nn.Parameter(torch.ones(hidden_size))
75
+ else:
76
+ self.weight = None
77
+ self.variance_epsilon = eps
78
+
79
+ def forward(self, hidden_states):
80
+ input_dtype = hidden_states.dtype
81
+ hidden_states = hidden_states.to(torch.float32)
82
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
83
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
84
+
85
+ if self.weight is not None:
86
+ return self.weight * hidden_states.to(input_dtype)
87
+ else:
88
+ return hidden_states.to(input_dtype)
89
+
90
+ class LlamaRotaryEmbedding(nn.Module):
91
+ def __init__(self, config, dim, base=10000, device=None, scaling_factor=1.0):
92
+ super().__init__()
93
+ self.scaling_factor = scaling_factor
94
+ self.dim = dim
95
+ self.base = base
96
+ self.config = config
97
+
98
+ self.rope_type = config.rope_type
99
+
100
+ self.factor = 2
101
+
102
+ max_position_embeddings = self.config.max_position_embeddings
103
+
104
+ if config.rope_type is None or config.rope_type == "default":
105
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
106
+ self.max_seq_len_cached = max_position_embeddings
107
+
108
+ elif config.rope_type == 'ntk':
109
+ assert self.config.orig_max_position_embeddings is not None
110
+ orig_max_position_embeddings = self.config.orig_max_position_embeddings
111
+
112
+ base = base * ((self.factor * max_position_embeddings / orig_max_position_embeddings) - (self.factor - 1)) ** (self.dim / (self.dim - 2))
113
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
114
+
115
+ self.max_seq_len_cached = orig_max_position_embeddings
116
+
117
+ elif config.rope_type == 'dynamic_ntk':
118
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
119
+ self.original_inv_freq = inv_freq
120
+ self.max_seq_len_cached = self.config.orig_max_position_embeddings
121
+
122
+ else:
123
+ raise ValueError(f"Not support rope_type: {config.rope_type}")
124
+
125
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
126
+
127
+
128
+ def _dynamic_frequency_update(self, position_ids, device):
129
+ """
130
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
131
+ 1 - growing beyond the cached sequence length (allow scaling)
132
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
133
+ """
134
+
135
+ seq_len = torch.max(position_ids) + 1
136
+ if seq_len > self.max_seq_len_cached: # growth
137
+ base = self.base * ((self.factor * seq_len / self.config.orig_max_position_embeddings) - (self.factor - 1)) ** (self.dim / (self.dim - 2))
138
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
139
+
140
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
141
+ self.max_seq_len_cached = seq_len
142
+
143
+ if seq_len < self.config.orig_max_position_embeddings and self.max_seq_len_cached > self.config.orig_max_position_embeddings: # reset
144
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
145
+ self.max_seq_len_cached = self.config.orig_max_position_embeddings
146
+
147
+
148
+ @torch.no_grad()
149
+ def forward(self, x, position_ids):
150
+ if self.rope_type == 'dynamic_ntk':
151
+ self._dynamic_frequency_update(position_ids, device=x.device)
152
+
153
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
154
+ position_ids_expanded = position_ids[:, None, :].float()
155
+
156
+ device_type = x.device.type
157
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
158
+ with torch.autocast(device_type=device_type, enabled=False):
159
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
160
+ emb = torch.cat((freqs, freqs), dim=-1)
161
+ cos = emb.cos()
162
+ sin = emb.sin()
163
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
164
+
165
+
166
+ def rotate_half(x):
167
+ """Rotates half the hidden dims of the input."""
168
+ x1 = x[..., : x.shape[-1] // 2]
169
+ x2 = x[..., x.shape[-1] // 2 :]
170
+ return torch.cat((-x2, x1), dim=-1)
171
+
172
+
173
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
174
+ """Applies Rotary Position Embedding to the query and key tensors."""
175
+ cos = cos.unsqueeze(unsqueeze_dim)
176
+ sin = sin.unsqueeze(unsqueeze_dim)
177
+ if q is not None:
178
+ q_embed = (q * cos) + (rotate_half(q) * sin)
179
+
180
+ else:
181
+ q_embed = None
182
+
183
+ if k is not None:
184
+ k_embed = (k * cos) + (rotate_half(k) * sin)
185
+ else:
186
+ k_embed = None
187
+ return q_embed, k_embed
188
+
189
+
190
+
191
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
192
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
193
+ if n_rep == 1:
194
+ return hidden_states
195
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
196
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
197
+
198
+
199
+
200
+ class AttentionDynamicCache(DynamicCache):
201
+
202
+ def __init__(self, config, batch_size, dtype=torch.float16, device=None, layer_type=None):
203
+ self.dtype = dtype
204
+
205
+ self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
206
+ self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
207
+
208
+ def update(
209
+ self,
210
+ key_states: torch.Tensor,
211
+ value_states: torch.Tensor,
212
+ layer_idx: int,
213
+ cache_kwargs: Optional[Dict[str, Any]] = None,
214
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
215
+
216
+ if self.key_cache[layer_idx].shape[-1] == 0:
217
+ self.key_cache[layer_idx] = key_states
218
+ self.value_cache[layer_idx] = value_states
219
+ else:
220
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
221
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
222
+
223
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
224
+
225
+ def get_seq_length(self, layer_idx=None) -> int:
226
+ if layer_idx is None:
227
+ max_key_len = max(cache.shape[-2] for cache in self.key_cache)
228
+ return max_key_len
229
+
230
+ if self.key_cache[layer_idx].shape[-1] == 0:
231
+ return 0
232
+
233
+ return self.key_cache[layer_idx].shape[-2]
234
+
235
+
236
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention
237
+ class NemotronFlashAttention(nn.Module):
238
+
239
+ def __init__(self, config: NemotronFlashConfig, layer_idx: Optional[int] = None, input_hidden_size=None, output_hidden_size=None):
240
+ super().__init__()
241
+ self.config = config
242
+ self.layer_idx = layer_idx
243
+ if layer_idx is None:
244
+ logger.warning_once(
245
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
246
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
247
+ "when creating this class."
248
+ )
249
+
250
+ self.hidden_size = config.attn_hidden_size if config.attn_hidden_size > 0 else config.hidden_size
251
+ self.num_heads = config.num_attention_heads
252
+ self.head_dim = self.hidden_size // self.num_heads
253
+ self.max_position_embeddings = config.max_position_embeddings
254
+ self.rope_theta = config.rope_theta
255
+
256
+ self.kq_head_dim = config.kq_head_dim if config.kq_head_dim > 0 else self.head_dim
257
+ self.v_head_dim = config.v_head_dim if config.v_head_dim > 0 else self.head_dim
258
+
259
+ self.num_key_value_heads = config.num_key_value_heads
260
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
261
+ self.is_causal = True
262
+ self.attention_dropout = config.attention_dropout
263
+
264
+ if (self.head_dim * self.num_heads) != self.hidden_size and self.kq_head_dim == self.head_dim:
265
+ raise ValueError(
266
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
267
+ f" and `num_heads`: {self.num_heads})."
268
+ )
269
+
270
+ self.q_proj = nn.Linear(self.hidden_size if input_hidden_size is None else input_hidden_size, self.num_heads * self.kq_head_dim, bias=False)
271
+ self.k_proj = nn.Linear(self.hidden_size if input_hidden_size is None else input_hidden_size, self.num_key_value_heads * self.kq_head_dim, bias=False)
272
+ self.v_proj = nn.Linear(self.hidden_size if input_hidden_size is None else input_hidden_size, self.num_key_value_heads * self.v_head_dim, bias=False)
273
+
274
+ if output_hidden_size is None:
275
+ output_hidden_size = self.hidden_size
276
+
277
+ self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, output_hidden_size, bias=False)
278
+
279
+ if self.config.kq_norm == "rms":
280
+ self.k_norm = NemotronFlashRMSNorm(self.kq_head_dim)
281
+ self.q_norm = NemotronFlashRMSNorm(self.kq_head_dim)
282
+ elif self.config.kq_norm == "none":
283
+ self.k_norm = None
284
+ self.q_norm = None
285
+ else:
286
+ raise NotImplementedError(f"Unknown kq_norm: {self.config.kq_norm}")
287
+
288
+ if self.config.rope:
289
+ self._init_rope()
290
+
291
+ def _init_rope(self):
292
+ self.rotary_emb = LlamaRotaryEmbedding(
293
+ config=self.config,
294
+ dim=self.kq_head_dim,
295
+ base=self.rope_theta,
296
+ device=torch.device("cuda"),
297
+ )
298
+
299
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
300
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ attention_mask: Optional[torch.Tensor] = None,
306
+ position_ids: Optional[torch.LongTensor] = None,
307
+ past_key_value: Optional[Cache] = None,
308
+ output_attentions: bool = False,
309
+ use_cache: bool = False,
310
+ use_swa=False,
311
+ query_states = None,
312
+ key_states=None,
313
+ value_states=None,
314
+ **kwargs,
315
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
316
+ raise NotImplementedError("NemotronFlashAttention is an abstract class. Use one of the subclasses.")
317
+
318
+
319
+
320
+ def _get_unpad_data(attention_mask):
321
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
322
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
323
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
324
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
325
+ return (
326
+ indices,
327
+ cu_seqlens,
328
+ max_seqlen_in_batch,
329
+ )
330
+
331
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2
332
+ class NemotronFlashFlashAttention2(NemotronFlashAttention):
333
+
334
+ def __init__(self, *args, **kwargs):
335
+ super().__init__(*args, **kwargs)
336
+
337
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor = None,
342
+ attention_mask: Optional[torch.Tensor] = None,
343
+ position_ids: Optional[torch.LongTensor] = None,
344
+ past_key_value: Optional[Cache] = None,
345
+ output_attentions: bool = False,
346
+ use_cache: bool = False,
347
+ use_swa=False,
348
+ query_states = None,
349
+ key_states=None,
350
+ value_states=None,
351
+ **kwargs,
352
+ ):
353
+
354
+ if "padding_mask" in kwargs:
355
+ warnings.warn(
356
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
357
+ )
358
+
359
+ # overwrite attention_mask with padding_mask
360
+ attention_mask = kwargs.pop("padding_mask")
361
+
362
+ bsz, q_len, _ = hidden_states.size()
363
+
364
+ query_states = self.q_proj(hidden_states)
365
+
366
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.kq_head_dim).transpose(1, 2).contiguous()
367
+
368
+ if self.q_norm is not None:
369
+ query_states = self.q_norm(query_states)
370
+
371
+ if self.config.rope:
372
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
373
+ query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin)
374
+
375
+ key_states = self.k_proj(hidden_states)
376
+ value_states = self.v_proj(hidden_states)
377
+
378
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.kq_head_dim).transpose(1, 2)
379
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.v_head_dim).transpose(1, 2)
380
+
381
+ if self.k_norm is not None:
382
+ key_states = self.k_norm(key_states)
383
+
384
+ if self.config.rope:
385
+ _, key_states = apply_rotary_pos_emb(None, key_states, cos, sin)
386
+
387
+
388
+ kv_seq_len = key_states.shape[-2]
389
+ if past_key_value is not None:
390
+ if self.layer_idx is None:
391
+ raise ValueError(
392
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
393
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
394
+ "with a layer index."
395
+ )
396
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
397
+
398
+ use_sliding_windows = (
399
+ _flash_supports_window_size
400
+ and getattr(self.config, "sliding_window", None) is not None
401
+ and kv_seq_len > self.config.sliding_window
402
+ and use_swa
403
+ )
404
+
405
+ if not _flash_supports_window_size:
406
+ logger.warning_once(
407
+ "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
408
+ " make sure to upgrade flash-attn library."
409
+ )
410
+
411
+ swa_processed_flag = False
412
+ if past_key_value is not None and use_cache:
413
+ kv_layer_idx = self.layer_idx
414
+
415
+ cache_has_contents = past_key_value.get_seq_length(kv_layer_idx) > 0
416
+
417
+ if (
418
+ getattr(self.config, "sliding_window", None) is not None
419
+ and kv_seq_len > self.config.sliding_window
420
+ and cache_has_contents
421
+ and use_swa
422
+ ):
423
+ slicing_tokens = 1 - self.config.sliding_window
424
+
425
+ past_key = past_key_value[kv_layer_idx][0]
426
+ past_value = past_key_value[kv_layer_idx][1]
427
+
428
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
429
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
430
+
431
+ past_key_value.key_cache[kv_layer_idx] = past_key
432
+ past_key_value.value_cache[kv_layer_idx] = past_value
433
+
434
+ if attention_mask is not None:
435
+ attention_mask = attention_mask[:, slicing_tokens:]
436
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
437
+
438
+ swa_processed_flag = True
439
+
440
+ key_states, value_states = past_key_value.update(key_states, value_states, kv_layer_idx)
441
+
442
+ key_states_no_repeat = key_states
443
+ value_states_no_repeat = value_states
444
+
445
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
446
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
447
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
448
+
449
+ input_dtype = query_states.dtype
450
+ if input_dtype == torch.float32:
451
+ if torch.is_autocast_enabled():
452
+ target_dtype = torch.get_autocast_gpu_dtype()
453
+ elif hasattr(self.config, "_pre_quantization_dtype"):
454
+ target_dtype = self.config._pre_quantization_dtype
455
+ else:
456
+ target_dtype = self.q_proj.weight.dtype
457
+
458
+ logger.warning_once(
459
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
460
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
461
+ f" {target_dtype}."
462
+ )
463
+
464
+ query_states = query_states.to(target_dtype)
465
+ key_states = key_states.to(target_dtype)
466
+ value_states = value_states.to(target_dtype)
467
+
468
+ # Reashape to the expected shape for Flash Attention
469
+ query_states = query_states.transpose(1, 2) # (batch, slen, num_heads, head_dim)
470
+ key_states = key_states.transpose(1, 2) # (batch, slen, num_heads, head_dim)
471
+ value_states = value_states.transpose(1, 2) # (batch, slen, num_heads, head_dim)
472
+
473
+ attn_output = self._flash_attention_forward(
474
+ query_states,
475
+ key_states,
476
+ value_states,
477
+ attention_mask,
478
+ q_len,
479
+ dropout=dropout_rate,
480
+ use_sliding_windows=use_sliding_windows and not swa_processed_flag,
481
+ )
482
+
483
+ v_dim = value_states.shape[-2] * value_states.shape[-1]
484
+ attn_output = attn_output.reshape(-1, q_len, v_dim).contiguous()
485
+
486
+ attn_output = self.o_proj(attn_output)
487
+
488
+ if not output_attentions:
489
+ attn_weights = None
490
+
491
+ return attn_output, attn_weights, past_key_value, (key_states_no_repeat, value_states_no_repeat)
492
+
493
+ def _flash_attention_forward(
494
+ self,
495
+ query_states,
496
+ key_states,
497
+ value_states,
498
+ attention_mask,
499
+ query_length,
500
+ dropout=0.0,
501
+ softmax_scale=None,
502
+ use_sliding_windows=False,
503
+ ):
504
+ """
505
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
506
+ first unpad the input, then computes the attention scores and pad the final attention scores.
507
+
508
+ Args:
509
+ query_states (`torch.Tensor`):
510
+ Input query states to be passed to Flash Attention API
511
+ key_states (`torch.Tensor`):
512
+ Input key states to be passed to Flash Attention API
513
+ value_states (`torch.Tensor`):
514
+ Input value states to be passed to Flash Attention API
515
+ attention_mask (`torch.Tensor`):
516
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
517
+ position of padding tokens and 1 for the position of non-padding tokens.
518
+ dropout (`int`, *optional*):
519
+ Attention dropout
520
+ softmax_scale (`float`, *optional*):
521
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
522
+ use_sliding_windows (`bool`, *optional*):
523
+ Whether to activate sliding window attention.
524
+ """
525
+ if not self._flash_attn_uses_top_left_mask:
526
+ causal = self.is_causal
527
+ else:
528
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
529
+ causal = self.is_causal and query_length != 1
530
+
531
+ if attention_mask is not None:
532
+ batch_size = query_states.shape[0]
533
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
534
+ query_states, key_states, value_states, attention_mask, query_length
535
+ )
536
+
537
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
538
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
539
+
540
+ if not use_sliding_windows:
541
+ attn_output_unpad = flash_attn_varlen_func(
542
+ query_states,
543
+ key_states,
544
+ value_states,
545
+ cu_seqlens_q=cu_seqlens_q,
546
+ cu_seqlens_k=cu_seqlens_k,
547
+ max_seqlen_q=max_seqlen_in_batch_q,
548
+ max_seqlen_k=max_seqlen_in_batch_k,
549
+ dropout_p=dropout,
550
+ softmax_scale=softmax_scale,
551
+ causal=causal,
552
+ )
553
+ else:
554
+ attn_output_unpad = flash_attn_varlen_func(
555
+ query_states,
556
+ key_states,
557
+ value_states,
558
+ cu_seqlens_q=cu_seqlens_q,
559
+ cu_seqlens_k=cu_seqlens_k,
560
+ max_seqlen_q=max_seqlen_in_batch_q,
561
+ max_seqlen_k=max_seqlen_in_batch_k,
562
+ dropout_p=dropout,
563
+ softmax_scale=softmax_scale,
564
+ causal=causal,
565
+ window_size=(self.config.sliding_window, self.config.sliding_window),
566
+ )
567
+
568
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
569
+ else:
570
+ if not use_sliding_windows:
571
+ attn_output = flash_attn_func(
572
+ query_states,
573
+ key_states,
574
+ value_states,
575
+ dropout,
576
+ softmax_scale=softmax_scale,
577
+ causal=causal,
578
+ )
579
+ else:
580
+ attn_output = flash_attn_func(
581
+ query_states,
582
+ key_states,
583
+ value_states,
584
+ dropout,
585
+ softmax_scale=softmax_scale,
586
+ causal=causal,
587
+ window_size=(self.config.sliding_window, self.config.sliding_window),
588
+ )
589
+
590
+ return attn_output
591
+
592
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
593
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
594
+
595
+ # On the first iteration we need to properly re-create the padding mask
596
+ # by slicing it on the proper place
597
+ if kv_seq_len != attention_mask.shape[-1]:
598
+ attention_mask_num_tokens = attention_mask.shape[-1]
599
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
600
+
601
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
602
+
603
+ if not self.training and not type(key_layer) == torch.Tensor: ## this is for handling Mamba2 with output type <class 'mamba_ssm.ops.triton.layernorm_gated.tTensor'>
604
+ key_layer = torch.tensor(key_layer.clone())
605
+ value_layer = torch.tensor(value_layer.clone())
606
+ query_layer = torch.tensor(query_layer.clone())
607
+
608
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
609
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
610
+
611
+ if query_length == kv_seq_len:
612
+ query_layer = index_first_axis(
613
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
614
+ )
615
+ cu_seqlens_q = cu_seqlens_k
616
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
617
+ indices_q = indices_k
618
+ elif query_length == 1:
619
+ max_seqlen_in_batch_q = 1
620
+ cu_seqlens_q = torch.arange(
621
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
622
+ ) # There is a memcpy here, that is very bad.
623
+ indices_q = cu_seqlens_q[:-1]
624
+ query_layer = query_layer.squeeze(1)
625
+ else:
626
+ # The -q_len: slice assumes left padding.
627
+ attention_mask = attention_mask[:, -query_length:]
628
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
629
+
630
+ return (
631
+ query_layer,
632
+ key_layer,
633
+ value_layer,
634
+ indices_q,
635
+ (cu_seqlens_q, cu_seqlens_k),
636
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
637
+ )
638
+
639
+
640
+
641
+ class NemotronFlashSDPAAttention(nn.Module):
642
+
643
+ def __init__(self, config, layer_idx: int, reuse_kv=False):
644
+ super().__init__()
645
+ self.config = config
646
+ self.layer_idx = layer_idx
647
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
648
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
649
+ self.scaling = self.head_dim**-0.5
650
+ self.attention_dropout = config.attention_dropout
651
+ self.is_causal = True
652
+
653
+ self.q_proj = nn.Linear(
654
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
655
+ )
656
+ self.k_proj = nn.Linear(
657
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
658
+ )
659
+ self.v_proj = nn.Linear(
660
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
661
+ )
662
+ self.o_proj = nn.Linear(
663
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
664
+ )
665
+
666
+ self.sliding_window = self.config.sliding_window if self.layer_idx not in self.config.global_attn_idx else None
667
+
668
+ self.rotary_emb = NemotronFlashRotaryEmbedding(config=config)
669
+
670
+ def forward(
671
+ self,
672
+ hidden_states: torch.Tensor,
673
+ attention_mask: Optional[torch.Tensor],
674
+ position_ids: Optional[torch.LongTensor] = None,
675
+ past_key_value: Optional[Cache] = None,
676
+ **kwargs,
677
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
678
+ input_shape = hidden_states.shape[:-1]
679
+ hidden_shape = (*input_shape, -1, self.head_dim)
680
+
681
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
682
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
683
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
684
+
685
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
686
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
687
+
688
+ if past_key_value is not None:
689
+ past_seen_tokens = past_key_value.get_seq_length()
690
+ cache_position = torch.arange(
691
+ past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
692
+ )
693
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
694
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
695
+
696
+ attention_interface = ALL_ATTENTION_FUNCTIONS['flash_attention_2']
697
+
698
+ attn_output, attn_weights = attention_interface(
699
+ self,
700
+ query_states,
701
+ key_states,
702
+ value_states,
703
+ attention_mask,
704
+ dropout=0.0 if not self.training else self.attention_dropout,
705
+ scaling=self.scaling,
706
+ sliding_window=self.sliding_window,
707
+ **kwargs,
708
+ )
709
+
710
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
711
+ attn_output = self.o_proj(attn_output)
712
+
713
+ return attn_output, attn_weights, past_key_value, (key_states, value_states)
714
+
715
+
716
+ class NemotronFlashRotaryEmbedding(nn.Module):
717
+ def __init__(self, config, device=None):
718
+ super().__init__()
719
+
720
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
721
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
722
+ else:
723
+ self.rope_type = "default"
724
+ self.max_seq_len_cached = config.max_position_embeddings
725
+ self.original_max_seq_len = config.max_position_embeddings
726
+
727
+ self.config = config
728
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
729
+
730
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
731
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
732
+ self.original_inv_freq = self.inv_freq
733
+
734
+ @torch.no_grad()
735
+ @dynamic_rope_update
736
+ def forward(self, x, position_ids):
737
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
738
+ position_ids_expanded = position_ids[:, None, :].float()
739
+
740
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
741
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
742
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
743
+ emb = torch.cat((freqs, freqs), dim=-1)
744
+ cos = emb.cos() * self.attention_scaling
745
+ sin = emb.sin() * self.attention_scaling
746
+
747
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
748
+
749
+
750
+ ## Interface to use TRTLLM AutoDeploy attention kernel, which enables CUDA Graph capture
751
+ class NemotronFlashFusedMHA(NemotronFlashAttention):
752
+ def __init__(self, *args, **kwargs):
753
+ super().__init__(*args, **kwargs)
754
+
755
+ self.fused_mha_interface = fused_mha_interface
756
+
757
+ def init_kv_cache(self, max_batch_size, max_seq_len, page_size=-1):
758
+ if hasattr(self, 'k_cache'):
759
+ del self.k_cache
760
+ del self.v_cache
761
+
762
+ if hasattr(self, 'page_table') and self.page_table is not None:
763
+ del self.page_table
764
+
765
+ import gc
766
+ gc.collect()
767
+
768
+ torch.cuda.empty_cache()
769
+
770
+ if page_size is not None and page_size > 0:
771
+ batch_max_pages = (max_seq_len + page_size - 1) // page_size
772
+ cache_max_pages = (max_batch_size * max_seq_len + page_size - 1) // page_size
773
+ self.k_cache = torch.zeros(cache_max_pages, page_size, self.num_key_value_heads, self.kq_head_dim).to(self.q_proj.weight)
774
+ self.v_cache = torch.zeros(cache_max_pages, page_size, self.num_key_value_heads, self.v_head_dim).to(self.q_proj.weight)
775
+
776
+ self.page_table = torch.zeros(max_batch_size, batch_max_pages, device=self.q_proj.weight.device, dtype=torch.int32)
777
+ else:
778
+ self.k_cache = torch.zeros(max_batch_size, max_seq_len, self.num_key_value_heads, self.kq_head_dim).to(self.q_proj.weight)
779
+ self.v_cache = torch.zeros(max_batch_size, max_seq_len, self.num_key_value_heads, self.v_head_dim).to(self.q_proj.weight)
780
+
781
+ self.page_table = None
782
+
783
+ self.max_seq_len = max_seq_len
784
+
785
+
786
+ def reset_kv_cache(self):
787
+ self.k_cache = self.k_cache.zero_()
788
+ self.v_cache = self.v_cache.zero_()
789
+
790
+ if self.page_table is not None:
791
+ self.page_table = self.page_table.zero_()
792
+
793
+
794
+ def forward(
795
+ self,
796
+ hidden_states: torch.Tensor = None,
797
+ attention_mask: Optional[torch.Tensor] = None,
798
+ position_ids: Optional[torch.LongTensor] = None,
799
+ past_key_value: Optional[Cache] = None,
800
+ output_attentions: bool = False,
801
+ use_cache: bool = False,
802
+ use_swa=False,
803
+ query_states = None,
804
+ key_states=None,
805
+ value_states=None,
806
+ **kwargs,
807
+ ):
808
+
809
+ if not hasattr(self, 'k_cache'):
810
+ self.init_kv_cache(max_batch_size=1, max_seq_len=8000)
811
+
812
+ bsz, q_len, _ = hidden_states.size()
813
+
814
+ query_states = self.q_proj(hidden_states)
815
+
816
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.kq_head_dim).transpose(1, 2).contiguous()
817
+
818
+ if self.q_norm is not None:
819
+ query_states = self.q_norm(query_states)
820
+
821
+ if self.config.rope:
822
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
823
+ query_states, _ = apply_rotary_pos_emb(query_states, None, cos, sin)
824
+
825
+ key_states = self.k_proj(hidden_states)
826
+ value_states = self.v_proj(hidden_states)
827
+
828
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.kq_head_dim).transpose(1, 2)
829
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.v_head_dim).transpose(1, 2)
830
+
831
+ if self.k_norm is not None:
832
+ key_states = self.k_norm(key_states)
833
+
834
+ if self.config.rope:
835
+ _, key_states = apply_rotary_pos_emb(None, key_states, cos, sin)
836
+
837
+ key_states_no_repeat = key_states
838
+ value_states_no_repeat = value_states
839
+
840
+ query_states = query_states.transpose(1, 2) # (batch, slen, num_heads, head_dim)
841
+ key_states = key_states.transpose(1, 2) # (batch, slen, num_kv_heads, head_dim)
842
+ value_states = value_states.transpose(1, 2) # (batch, slen, num_kv_heads, head_dim)
843
+
844
+ if self.k_cache.device != query_states.device:
845
+ self.k_cache = self.k_cache.to(query_states)
846
+ self.v_cache = self.v_cache.to(query_states)
847
+
848
+ attn_output = self.fused_mha_interface(
849
+ query_states,
850
+ key_states,
851
+ value_states,
852
+ k_cache=self.k_cache,
853
+ v_cache=self.v_cache,
854
+ page_table=self.page_table,
855
+ max_seq_len=self.max_seq_len,
856
+ position_ids=position_ids,
857
+ )
858
+
859
+ v_dim = query_states.shape[-2] * value_states.shape[-1]
860
+ attn_output = attn_output.reshape(bsz, q_len, v_dim).contiguous()
861
+
862
+ attn_output = self.o_proj(attn_output)
863
+
864
+ if not output_attentions:
865
+ attn_weights = None
866
+
867
+ return attn_output, attn_weights, past_key_value, (key_states_no_repeat, value_states_no_repeat)
868
+
869
+
870
+ JAMBA_ATTENTION_CLASSES = {
871
+ "flash_attention_2": NemotronFlashFlashAttention2,
872
+ "fused_mha": NemotronFlashFusedMHA,
873
+ "sdpa": NemotronFlashSDPAAttention,
874
+ }
875
+
876
+ class NemotronFlashMLP(nn.Module):
877
+ def __init__(self, config: NemotronFlashConfig, layer_idx: int):
878
+ super().__init__()
879
+ self.config = config
880
+ self.act_fn_name = config.mlp_hidden_act
881
+ self.act_fn = ACT2FN[self.act_fn_name]
882
+
883
+ if config.ffn_expand_ratio is not None:
884
+ self.ffn_dim = int(config.ffn_expand_ratio * config.hidden_size) // 128 * 128
885
+ else:
886
+ self.ffn_dim = config.intermediate_size
887
+
888
+ self.hidden_dim = config.hidden_size
889
+
890
+ self.layer_idx = layer_idx
891
+
892
+ if self.act_fn_name == "silu":
893
+ self.gate_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
894
+ self.down_proj = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
895
+ self.up_proj = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
896
+
897
+
898
+ def forward(self, x):
899
+ if self.act_fn_name == "silu":
900
+ output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
901
+ elif self.act_fn_name == "relu2":
902
+ output = self.down_proj(self.act_fn(self.up_proj(x)))
903
+ else:
904
+ raise NotImplementedError(f"No such hidden_act: {self.act_fn_name}")
905
+
906
+ return output
907
+
908
+
909
+ class NemotronFlashAttentionDecoderLayer(nn.Module):
910
+ def __init__(self, config: NemotronFlashConfig, layer_idx: int,):
911
+ super().__init__()
912
+
913
+ self.config = config
914
+
915
+ self.layer_idx = layer_idx
916
+
917
+ self.self_attn = JAMBA_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
918
+
919
+ if self.config.intermediate_size > 0:
920
+ self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx)
921
+ else:
922
+ self.ffn = None
923
+
924
+ self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
925
+ self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
926
+
927
+ def forward(
928
+ self,
929
+ hidden_states: torch.Tensor,
930
+ attention_mask: Optional[torch.Tensor] = None,
931
+ position_ids: Optional[torch.LongTensor] = None,
932
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
933
+ output_attentions: Optional[bool] = False,
934
+ use_cache: Optional[bool] = False,
935
+ use_swa=False,
936
+ **kwargs,
937
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
938
+ if "padding_mask" in kwargs:
939
+ warnings.warn(
940
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
941
+ )
942
+
943
+ if position_ids is not None and position_ids.shape[1] != hidden_states.shape[1]:
944
+ position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
945
+
946
+ residual = hidden_states
947
+
948
+ if self.input_layernorm is not None:
949
+ hidden_states = self.input_layernorm(hidden_states)
950
+
951
+ hidden_states, self_attn_weights, present_key_value, current_kv = self.self_attn(
952
+ hidden_states=hidden_states,
953
+ attention_mask=attention_mask,
954
+ position_ids=position_ids,
955
+ past_key_value=past_key_value,
956
+ output_attentions=output_attentions,
957
+ use_cache=use_cache,
958
+ use_swa=use_swa,
959
+ )
960
+
961
+ hidden_states = residual + hidden_states
962
+
963
+ if self.ffn is not None:
964
+ residual = hidden_states
965
+ if self.pre_ffn_layernorm is not None:
966
+ hidden_states = self.pre_ffn_layernorm(hidden_states)
967
+ hidden_states = self.ffn(hidden_states)
968
+
969
+ hidden_states = residual + hidden_states
970
+
971
+ outputs = (hidden_states,)
972
+
973
+ if output_attentions:
974
+ outputs += (self_attn_weights,)
975
+
976
+ if use_cache:
977
+ outputs += (present_key_value,)
978
+
979
+ outputs += (current_kv,)
980
+
981
+ return outputs
982
+
983
+
984
+
985
+ class FFNDecoderLayer(nn.Module):
986
+ def __init__(self, config: NemotronFlashConfig, layer_idx: int):
987
+ super().__init__()
988
+
989
+ self.config = config
990
+ self.layer_idx = layer_idx
991
+ self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx)
992
+ self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
993
+
994
+ def forward(
995
+ self,
996
+ hidden_states: torch.Tensor,
997
+ attention_mask: Optional[torch.Tensor] = None,
998
+ position_ids: Optional[torch.LongTensor] = None,
999
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1000
+ output_attentions: Optional[bool] = False,
1001
+ use_cache: Optional[bool] = False,
1002
+ use_swa=False,
1003
+ **kwargs,
1004
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1005
+ if "padding_mask" in kwargs:
1006
+ warnings.warn(
1007
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1008
+ )
1009
+
1010
+ residual = hidden_states
1011
+ if self.pre_ffn_layernorm is not None:
1012
+ hidden_states = self.pre_ffn_layernorm(hidden_states)
1013
+ hidden_states = self.ffn(hidden_states)
1014
+
1015
+ hidden_states = residual + hidden_states
1016
+
1017
+ outputs = (hidden_states,)
1018
+
1019
+ if output_attentions:
1020
+ outputs += (None,)
1021
+
1022
+ if use_cache:
1023
+ outputs += (None,)
1024
+
1025
+ return outputs
1026
+
1027
+
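# A small pure-Python sketch of the layer-output tuple convention above: every decoder layer
# returns a tuple that starts with hidden_states; optional attention weights and cache entries
# follow, and layers without them (e.g. FFNDecoderLayer) pad those slots with None so the model
# loop can index uniformly (it uses `layer_outputs[2 if output_attentions else 1]` further below).
def unpack_layer_outputs(layer_outputs, output_attentions, use_cache):
    hidden_states = layer_outputs[0]
    attn_weights = layer_outputs[1] if output_attentions else None
    present_kv = layer_outputs[2 if output_attentions else 1] if use_cache else None
    return hidden_states, attn_weights, present_kv

# e.g. with output_attentions=False, use_cache=True:
print(unpack_layer_outputs(("h", "kv", "current_kv"), False, True))  # ('h', None, 'kv')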
1028
+ class NemotronFlashMambaDecoderLayer(nn.Module):
1029
+ def __init__(self, config: NemotronFlashConfig, layer_idx: int):
1030
+ super().__init__()
1031
+
1032
+ self.config = config
1033
+ self.layer_idx = layer_idx
1034
+
1035
+ self.mamba = Mamba2(config=config, layer_idx=layer_idx)
1036
+
1037
+ self.intermediate_size = config.intermediate_size
1038
+ if self.intermediate_size > 0:
1039
+ self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx)
1040
+
1041
+ self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1042
+
1043
+ if self.intermediate_size > 0:
1044
+ self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1045
+ else:
1046
+ self.pre_ffn_layernorm = None
1047
+
1048
+
1049
+ def forward(
1050
+ self,
1051
+ hidden_states: torch.Tensor,
1052
+ attention_mask: Optional[torch.Tensor] = None,
1053
+ position_ids: Optional[torch.LongTensor] = None,
1054
+ past_key_value: Optional[AttentionDynamicCache] = None,
1055
+ output_attentions: Optional[bool] = False,
1056
+ use_cache: Optional[bool] = False,
1057
+ use_swa=False,
1058
+ mamba_inference_params=None,
1059
+ **kwargs,
1060
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1061
+ if "padding_mask" in kwargs:
1062
+ warnings.warn(
1063
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1064
+ )
1065
+
1066
+ if position_ids is not None and position_ids.shape[1] != hidden_states.shape[1]:
1067
+ position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
1068
+
1069
+ residual = hidden_states
1070
+
1071
+ if self.input_layernorm is not None:
1072
+ hidden_states = self.input_layernorm(hidden_states)
1073
+
1074
+ hidden_states, present_key_value = self.mamba(
1075
+ hidden_states=hidden_states,
1076
+ past_key_value=past_key_value,
1077
+ attention_mask=attention_mask,
1078
+ inference_params=mamba_inference_params,
1079
+ )
1080
+
1081
+ attn_key_value = None
1082
+
1083
+ hidden_states = residual + hidden_states
1084
+
1085
+ if self.intermediate_size > 0:
1086
+ residual = hidden_states
1087
+
1088
+ if self.pre_ffn_layernorm is not None:
1089
+ hidden_states = self.pre_ffn_layernorm(hidden_states)
1090
+
1091
+ hidden_states = self.ffn(hidden_states)
1092
+
1093
+ hidden_states = residual + hidden_states
1094
+
1095
+ outputs = (hidden_states,)
1096
+
1097
+ if use_cache:
1098
+ outputs += (present_key_value,)
1099
+
1100
+ outputs += (attn_key_value,)
1101
+
1102
+ return outputs
1103
+
1104
+ def _get_past_seqlen(self, past_key_value, seqlen):
1105
+ if past_key_value is None:
1106
+ return seqlen
1107
+ past_seqlen = past_key_value.get_seq_length(self.layer_idx)
1108
+
1109
+ if past_seqlen == 0:
1110
+ return seqlen
1111
+
1112
+ return past_seqlen
1113
+
1114
+
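# A minimal sketch of the `_get_past_seqlen` logic above: fall back to the current sequence
# length when there is no cache (or the cache is still empty), otherwise use the cached length.
# `StubCache` is a hypothetical stand-in for AttentionDynamicCache, used only for illustration.
class StubCache:
    def __init__(self, length):
        self._length = length
    def get_seq_length(self, layer_idx=0):
        return self._length

def get_past_seqlen(past_key_value, seqlen, layer_idx=0):
    if past_key_value is None:
        return seqlen
    past_seqlen = past_key_value.get_seq_length(layer_idx)
    return seqlen if past_seqlen == 0 else past_seqlen

print(get_past_seqlen(None, 7), get_past_seqlen(StubCache(0), 7), get_past_seqlen(StubCache(12), 7))
# 7 7 12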
1115
+
1116
+ class NemotronFlashHybridDecoderLayer(nn.Module):
1117
+ def __init__(self, config: NemotronFlashConfig, layer_idx: int):
1118
+ super().__init__()
1119
+
1120
+ self.config = config
1121
+
1122
+ self.layer_idx = layer_idx
1123
+
1124
+ if config.hybrid_decoder_layer == 'mamba':
1125
+ self.mamba = Mamba2(config=config, layer_idx=layer_idx)
1126
+ elif config.hybrid_decoder_layer == 'deltanet':
1127
+ if config.layer_types is not None:
1128
+ deltanet_idx = sum(1 for i in range(layer_idx) if config.layer_types[i] == 'deltanet')
1129
+ else:
1130
+ deltanet_idx = layer_idx
1131
+
1132
+ self.gla = DeltaNet(hidden_size=config.hidden_size, num_heads=config.num_attention_heads, layer_idx=deltanet_idx, config=self.config)
1133
+ else:
1134
+ raise ValueError(f"Not supported: {config.hybrid_decoder_layer}")
1135
+
1136
+ self.config = config
1137
+
1138
+ if self.config.intermediate_size > 0:
1139
+ self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx)
1140
+ self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1141
+ else:
1142
+ self.ffn = None
1143
+ self.pre_ffn_layernorm = None
1144
+
1145
+ self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1146
+
1147
+
1148
+ def forward(
1149
+ self,
1150
+ hidden_states: torch.Tensor,
1151
+ attention_mask: Optional[torch.Tensor] = None,
1152
+ position_ids: Optional[torch.LongTensor] = None,
1153
+ past_key_value: Optional[AttentionDynamicCache] = None,
1154
+ output_attentions: Optional[bool] = False,
1155
+ use_cache: Optional[bool] = False,
1156
+ fla_past_key_values = None,
1157
+ mamba_inference_params = None,
1158
+ use_swa=False,
1159
+ **kwargs,
1160
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1161
+ if "padding_mask" in kwargs:
1162
+ warnings.warn(
1163
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1164
+ )
1165
+
1166
+ residual = hidden_states
1167
+
1168
+ hidden_states = self.input_layernorm(hidden_states)
1169
+
1170
+ if self.config.hybrid_decoder_layer == 'mamba':
1171
+ hybrid_op_hidden_states, mamba_present_key_value = self.mamba(
1172
+ hidden_states=hidden_states,
1173
+ past_key_value=past_key_value,
1174
+ attention_mask=attention_mask,
1175
+ inference_params=mamba_inference_params,
1176
+ )
1177
+
1178
+ else:
1179
+ hybrid_op_hidden_states, _, fla_past_key_values = self.gla(
1180
+ hidden_states=hidden_states,
1181
+ attention_mask=attention_mask,
1182
+ past_key_values=fla_past_key_values,
1183
+ use_cache=use_cache,
1184
+ )
1185
+
1186
+ self_attn_weights = self_attn_present_key_value = current_kv = None
1187
+
1188
+ hidden_states = residual + hybrid_op_hidden_states
1189
+
1190
+ if self.ffn is not None:
1191
+ residual = hidden_states
1192
+ hidden_states = self.pre_ffn_layernorm(hidden_states)
1193
+
1194
+ hidden_states = self.ffn(hidden_states)
1195
+
1196
+ hidden_states = residual + hidden_states
1197
+
1198
+ outputs = (hidden_states,)
1199
+
1200
+ if output_attentions:
1201
+ outputs += (self_attn_weights,)
1202
+
1203
+ if use_cache:
1204
+ outputs += (self_attn_present_key_value,)
1205
+
1206
+ outputs += (current_kv,)
1207
+
1208
+
1209
+ return outputs
1210
+
1211
+
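# A small sketch of how the DeltaNet layer index is derived in the constructor above: when
# `layer_types` is available, the DeltaNet block at model layer `layer_idx` is given the count
# of *preceding* 'deltanet' layers as its own index; otherwise it reuses `layer_idx` directly.
def deltanet_index(layer_types, layer_idx):
    if layer_types is not None:
        return sum(1 for i in range(layer_idx) if layer_types[i] == 'deltanet')
    return layer_idx

layer_types = ['m', 'deltanet', 'a', 'deltanet', 'f']
print(deltanet_index(layer_types, 3))  # 1 -> the second DeltaNet block gets index 1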
1212
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel
1213
+ class NemotronFlashPreTrainedModel(PreTrainedModel):
1214
+ config_class = NemotronFlashConfig
1215
+ base_model_prefix = "model"
1216
+ supports_gradient_checkpointing = True
1217
+ _no_split_modules = ["NemotronFlashAttentionDecoderLayer", "NemotronFlashMambaDecoderLayer"]
1218
+ _skip_keys_device_placement = "past_key_values"
1219
+ _supports_flash_attn_2 = True
1220
+ _supports_sdpa = True
1221
+ _supports_cache_class = True
1222
+
1223
+ def _init_weights(self, module):
1224
+ std = self.config.initializer_range
1225
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
1226
+ module.weight.data.normal_(mean=0.0, std=std)
1227
+ if module.bias is not None:
1228
+ module.bias.data.zero_()
1229
+ elif isinstance(module, nn.Embedding):
1230
+ module.weight.data.normal_(mean=0.0, std=std)
1231
+ if module.padding_idx is not None:
1232
+ module.weight.data[module.padding_idx].zero_()
1233
+
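# A self-contained sketch of the weight initialisation above: normal(0, initializer_range) for
# Linear/Conv1d/Embedding weights, zeros for biases and for the padding embedding row. The
# std value below is illustrative, standing in for config.initializer_range.
import torch.nn as nn

std = 0.02
emb = nn.Embedding(10, 4, padding_idx=0)
lin = nn.Linear(4, 4)
for module in (emb, lin):
    module.weight.data.normal_(mean=0.0, std=std)
if lin.bias is not None:
    lin.bias.data.zero_()
if emb.padding_idx is not None:
    emb.weight.data[emb.padding_idx].zero_()
print(emb.weight[0].abs().sum().item())  # 0.0 -> padding row zeroed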
1234
+
1235
+ # Adapted from transformers.models.mistral.modeling_mistral.MistralModel
1236
+ class NemotronFlashModel(NemotronFlashPreTrainedModel):
1237
+
1238
+ def __init__(self, config: NemotronFlashConfig):
1239
+ super().__init__(config)
1240
+
1241
+ config.attn_implementation = config.attn_implementation_new
1242
+ config._attn_implementation = config.attn_implementation_new
1243
+
1244
+ self.config = config
1245
+
1246
+ self.padding_idx = config.pad_token_id
1247
+ self.vocab_size = config.vocab_size
1248
+
1249
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1250
+
1251
+ decoder_layers = []
1252
+
1253
+ layer_type = []
1254
+ for i in range(config.num_hidden_layers):
1255
+ if config.layer_types[i] in ['deltanet']:
1256
+ layer_type.append('m')
1257
+ config_new = copy.deepcopy(config)
1258
+ config_new.hybrid_decoder_layer = 'deltanet'
1259
+ decoder_layer = NemotronFlashHybridDecoderLayer(config_new, layer_idx=i)
1260
+ elif config.layer_types[i] in ['m', 'm2']:
1261
+ layer_type.append('m')
1262
+ decoder_layer = NemotronFlashMambaDecoderLayer(config, layer_idx=i)
1263
+ elif config.layer_types[i] == 'a':
1264
+ layer_type.append('a')
1265
+ decoder_layer = NemotronFlashAttentionDecoderLayer(config, layer_idx=i)
1266
+ elif config.layer_types[i] == 'f':
1267
+ layer_type.append('a')
1268
+ decoder_layer = FFNDecoderLayer(config, layer_idx=i)
1269
+ else:
1270
+ raise ValueError(f"Unsupported layer type {config.layer_types[i]}")
1271
+
1272
+ decoder_layers.append(decoder_layer)
1273
+
1274
+ config.layer_type = layer_type
1275
+
1276
+ if config.sliding_window is not None:
1277
+ self.sliding_window = config.sliding_window
1278
+ self.global_attn_idx = config.global_attn_idx
1279
+ else:
1280
+ self.sliding_window = None
1281
+ self.global_attn_idx = None
1282
+
1283
+ self.layers = nn.ModuleList(decoder_layers)
1284
+
1285
+ self._attn_implementation = config.attn_implementation
1286
+
1287
+ self.final_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1288
+
1289
+ if self.config.num_memory_tokens > 0:
1290
+ self.memory_tokens = nn.Parameter(torch.randn(self.config.num_memory_tokens, self.config.hidden_size))
1291
+
1292
+ self.gradient_checkpointing = False
1293
+
1294
+ self.post_init()
1295
+
1296
+ self.has_previous_state = False
1297
+
1298
+
1299
+ def get_input_embeddings(self):
1300
+ return self.embed_tokens
1301
+
1302
+ def set_input_embeddings(self, value):
1303
+ self.embed_tokens = value
1304
+
1305
+
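# Note on `config.layer_types` (descriptive only, based on the constructor above):
#   'deltanet'  -> NemotronFlashHybridDecoderLayer configured with a DeltaNet mixer
#   'm' / 'm2'  -> NemotronFlashMambaDecoderLayer (Mamba2 mixer)
#   'a'         -> NemotronFlashAttentionDecoderLayer (attention mixer)
#   'f'         -> FFNDecoderLayer (feed-forward only)
# e.g. layer_types = ['m', 'm', 'a', 'f'] builds two Mamba layers, one attention layer and one
# FFN-only layer; any other code raises a ValueError.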
1306
+ def forward(
1307
+ self,
1308
+ input_ids: torch.LongTensor = None,
1309
+ attention_mask: Optional[torch.Tensor] = None,
1310
+ position_ids: Optional[torch.LongTensor] = None,
1311
+ past_key_values: Optional[Union[List[torch.FloatTensor], AttentionDynamicCache]] = None,
1312
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1313
+ use_cache: Optional[bool] = None,
1314
+ output_attentions: Optional[bool] = None,
1315
+ output_hidden_states: Optional[bool] = None,
1316
+ return_dict: Optional[bool] = None,
1317
+ fla_past_key_values = None,
1318
+ mamba_inference_params = None,
1319
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
1320
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1321
+
1322
+ output_hidden_states = (
1323
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1324
+ )
1325
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1326
+
1327
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1328
+
1329
+ if input_ids is not None and inputs_embeds is not None:
1330
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1331
+ elif input_ids is not None:
1332
+ batch_size, seq_length = input_ids.shape
1333
+ elif inputs_embeds is not None:
1334
+ batch_size, seq_length, _ = inputs_embeds.shape
1335
+ else:
1336
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1337
+
1338
+ if self.gradient_checkpointing and self.training:
1339
+ if use_cache:
1340
+ logger.warning_once(
1341
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1342
+ )
1343
+ use_cache = False
1344
+
1345
+ if position_ids is None:
1346
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1347
+ position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device
1348
+ )
1349
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1350
+ else:
1351
+ if self.config.num_memory_tokens > 0 and past_key_values is not None and not self.has_previous_state:
1352
+ position_ids = position_ids.view(-1, seq_length + self.config.num_memory_tokens).long()
1353
+ else:
1354
+ position_ids = position_ids.view(-1, seq_length).long()
1355
+
1356
+ if inputs_embeds is None:
1357
+ inputs_embeds = self.embed_tokens(input_ids)
1358
+
1359
+ ori_b, ori_n = inputs_embeds.shape[0], inputs_embeds.shape[1]
1360
+
1361
+ if self.config.num_memory_tokens > 0 and (past_key_values is None or not self.has_previous_state):
1362
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = inputs_embeds.shape[0]) # repeat the learned memory tokens across the batch so they can be prepended to the input embeddings
1363
+ inputs_embeds, mem_packed_shape = pack((mem, inputs_embeds), 'b * d')
1364
+
1365
+ if position_ids is not None and position_ids.shape[1] != inputs_embeds.shape[1]:
1366
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
1367
+
1368
+ if attention_mask is not None and attention_mask.shape[1] < inputs_embeds.shape[1]:
1369
+ assert attention_mask.shape[1] + self.config.num_memory_tokens == inputs_embeds.shape[1]
1370
+ attention_mask = torch.cat([torch.ones(inputs_embeds.shape[0], self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1)
1371
+
1372
+
1373
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
1374
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1375
+ if is_padding_right:
1376
+ raise ValueError(
1377
+ "You are attempting to perform batched generation with padding_side='right'"
1378
+ " this may lead to unexpected behaviour for Flash Attention version of NemotronFlash. Make sure to "
1379
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1380
+ )
1381
+
1382
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1383
+
1384
+ hidden_states = inputs_embeds
1385
+
1386
+ all_hidden_states = () if output_hidden_states else None
1387
+ all_self_attns = () if output_attentions else None
1388
+ next_decoder_cache = None
1389
+
1390
+ for i, decoder_layer in enumerate(self.layers):
1391
+ if output_hidden_states:
1392
+ all_hidden_states += (hidden_states,)
1393
+
1394
+ if self.gradient_checkpointing and self.training:
1395
+ layer_outputs = self._gradient_checkpointing_func(
1396
+ decoder_layer.__call__,
1397
+ hidden_states,
1398
+ attention_mask,
1399
+ position_ids,
1400
+ past_key_values,
1401
+ output_attentions,
1402
+ use_cache,
1403
+ )
1404
+ else:
1405
+ layer_outputs = decoder_layer(
1406
+ hidden_states,
1407
+ attention_mask=attention_mask,
1408
+ position_ids=position_ids,
1409
+ past_key_value=past_key_values,
1410
+ output_attentions=output_attentions,
1411
+ use_cache=use_cache,
1412
+ use_swa=self.sliding_window is not None and i not in self.global_attn_idx,
1413
+ fla_past_key_values=fla_past_key_values,
1414
+ mamba_inference_params=mamba_inference_params,
1415
+ )
1416
+
1417
+ hidden_states = layer_outputs[0]
1418
+
1419
+ if use_cache:
1420
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1421
+
1422
+ if output_attentions:
1423
+ all_self_attns += (layer_outputs[1],)
1424
+
1425
+ if self.final_layernorm is not None:
1426
+ hidden_states = self.final_layernorm(hidden_states)
1427
+
1428
+ if output_hidden_states:
1429
+ all_hidden_states += (hidden_states,)
1430
+
1431
+ if self.config.num_memory_tokens > 0 and (past_key_values is None or not self.has_previous_state):
1432
+ mem, hidden_states = unpack(hidden_states, mem_packed_shape, 'b * d')
1433
+ hidden_states = hidden_states[:, :ori_n, :]
1434
+
1435
+ if past_key_values is not None and not self.has_previous_state:
1436
+ self.has_previous_state = True
1437
+
1438
+ next_cache = None
1439
+ if use_cache:
1440
+ next_cache = next_decoder_cache
1441
+
1442
+ if not return_dict:
1443
+ return tuple(
1444
+ v
1445
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1446
+ if v is not None
1447
+ )
1448
+ return MoeModelOutputWithPast(
1449
+ last_hidden_state=hidden_states,
1450
+ past_key_values=past_key_values if (fla_past_key_values is None and mamba_inference_params is None) else (past_key_values, fla_past_key_values, mamba_inference_params),
1451
+ hidden_states=all_hidden_states,
1452
+ attentions=all_self_attns,
1453
+ )
1454
+
1455
+
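# A minimal sketch of the memory-token prepend/strip round trip used in NemotronFlashModel.forward
# above, using einops repeat/pack/unpack. All sizes below are illustrative.
import torch
from einops import repeat, pack, unpack

num_memory_tokens, hidden = 4, 8
memory_tokens = torch.randn(num_memory_tokens, hidden)    # analogous to self.memory_tokens
inputs_embeds = torch.randn(2, 5, hidden)                  # (batch, seq, hidden)

mem = repeat(memory_tokens, 'n d -> b n d', b=inputs_embeds.shape[0])
packed, mem_packed_shape = pack((mem, inputs_embeds), 'b * d')   # prepend along the sequence dim
print(packed.shape)                                        # torch.Size([2, 9, 8])

mem, hidden_states = unpack(packed, mem_packed_shape, 'b * d')   # strip the memory tokens back off
print(hidden_states.shape)                                 # torch.Size([2, 5, 8])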
1456
+ # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->NemotronFlash
1457
+ class NemotronFlashForCausalLM(NemotronFlashPreTrainedModel, GenerationMixin):
1458
+ _tied_weights_keys = ["lm_head.weight"]
1459
+
1460
+ def __init__(self, config: NemotronFlashConfig):
1461
+ super().__init__(config)
1462
+ self.config = config
1463
+ self.model = NemotronFlashModel(config)
1464
+ self.vocab_size = config.vocab_size
1465
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1466
+
1467
+ self.post_init()
1468
+
1469
+ def get_input_embeddings(self):
1470
+ return self.model.embed_tokens
1471
+
1472
+ def set_input_embeddings(self, value):
1473
+ self.model.embed_tokens = value
1474
+
1475
+ def get_output_embeddings(self):
1476
+ return self.lm_head
1477
+
1478
+ def set_output_embeddings(self, new_embeddings):
1479
+ self.lm_head = new_embeddings
1480
+
1481
+ def set_decoder(self, decoder):
1482
+ self.model = decoder
1483
+
1484
+ def get_decoder(self):
1485
+ return self.model
1486
+
1487
+ @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1488
+ def forward(
1489
+ self,
1490
+ input_ids: torch.LongTensor = None,
1491
+ attention_mask: Optional[torch.Tensor] = None,
1492
+ position_ids: Optional[torch.LongTensor] = None,
1493
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1494
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1495
+ labels: Optional[torch.LongTensor] = None,
1496
+ use_cache: Optional[bool] = None,
1497
+ output_attentions: Optional[bool] = None,
1498
+ output_hidden_states: Optional[bool] = None,
1499
+ return_dict: Optional[bool] = None,
1500
+ calc_logits_for_entire_prompt: Optional[bool] = True,
1501
+ fla_past_key_values = None,
1502
+ mamba_inference_params = None,
1503
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1504
+ r"""
1505
+ Args:
1506
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1507
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1508
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1509
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1510
+
1511
+ calc_logits_for_entire_prompt (`bool`, *optional*):
1512
+ Whether or not to calculate the logits for the entire prompt, or just the last token. Only last token
1513
+ logits are needed for generation, and calculating them only for that token can save memory,
1514
+ which becomes pretty significant for long sequences.
1515
+
1516
+ Returns:
1517
+ """
1518
+
1519
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1520
+
1521
+ output_hidden_states = (
1522
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1523
+ )
1524
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1525
+
1526
+ outputs = self.model(
1527
+ input_ids=input_ids,
1528
+ attention_mask=attention_mask,
1529
+ position_ids=position_ids,
1530
+ past_key_values=past_key_values,
1531
+ inputs_embeds=inputs_embeds,
1532
+ use_cache=use_cache,
1533
+ output_attentions=output_attentions,
1534
+ output_hidden_states=output_hidden_states,
1535
+ fla_past_key_values=fla_past_key_values,
1536
+ mamba_inference_params=mamba_inference_params,
1537
+ return_dict=return_dict,
1538
+ )
1539
+
1540
+ hidden_states = outputs[0]
1541
+ if calc_logits_for_entire_prompt:
1542
+ logits = self.lm_head(hidden_states)
1543
+ else:
1544
+ logits = self.lm_head(hidden_states[..., -1:, :])
1545
+
1546
+ logits = logits / self.lm_head.weight.norm(p=2, dim=1)
1547
+
1548
+ logits = logits.float()
1549
+
1550
+ loss = None
1551
+ if labels is not None:
1552
+ shift_logits = logits[..., :-1, :].contiguous()
1553
+ shift_labels = labels[..., 1:].contiguous()
1554
+
1555
+ loss_fct = CrossEntropyLoss()
1556
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1557
+ shift_labels = shift_labels.view(-1)
1558
+ # Enable model parallelism
1559
+ shift_labels = shift_labels.to(shift_logits.device)
1560
+ loss = loss_fct(shift_logits, shift_labels)
1561
+
1562
+ if not return_dict:
1563
+ output = (logits,) + outputs[1:]
1564
+ return (loss,) + output if loss is not None else output
1565
+
1566
+ return MoeCausalLMOutputWithPast(
1567
+ loss=loss,
1568
+ logits=logits,
1569
+ past_key_values=outputs.past_key_values,
1570
+ hidden_states=outputs.hidden_states,
1571
+ attentions=outputs.attentions,
1572
+ )
1573
+
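# Note on the logits above (descriptive only): dividing by the per-row L2 norm of
# `lm_head.weight` makes each logit proportional to the cosine similarity between the hidden
# state and that token's output embedding (up to the hidden-state norm), i.e.
#   logits[b, t, v] = <h[b, t], W[v]> / ||W[v]||_2
# rather than the raw dot product produced by a standard lm_head.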
1574
+ def get_init_cache(self, max_seqlen, batch_size=1):
1575
+ past_key_values = AttentionDynamicCache(
1576
+ self.config, batch_size, self.dtype, device=self.device, layer_type=self.config.layer_type
1577
+ )
1578
+
1579
+ mamba_inference_params = InferenceParams(max_seqlen=max_seqlen, max_batch_size=batch_size)
1580
+
1581
+ fla_past_key_values = fla_cache.from_legacy_cache(None)
1582
+
1583
+ return past_key_values, fla_past_key_values, mamba_inference_params
1584
+
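# Note (descriptive only): generation keeps three separate state objects, one per mixer family:
# `past_key_values` (AttentionDynamicCache) for the attention layers, `fla_past_key_values` for
# the DeltaNet / linear-attention layers, and `mamba_inference_params` (conv + SSM states) for
# the Mamba2 layers. All three are created together here and threaded through every forward call.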
1585
+
1586
+ def init_cuda_graph_generation(
1587
+ self,
1588
+ max_new_tokens=128,
1589
+ batch_size=1,
1590
+ device=None,
1591
+ ):
1592
+ """
1593
+ Initialize CUDA graph for generation with proper cache handling and warmup.
1594
+ This function should be called once before generation to set up the graph.
1595
+
1596
+ Args:
1597
+ max_new_tokens: Maximum number of new tokens to generate
1598
+ batch_size: Batch size for generation
1599
+ device: Device to use (defaults to model device)
1600
+
1601
+ Returns:
1602
+ generation_state: Dictionary containing all necessary state for generation
1603
+ """
1604
+ if device is None:
1605
+ device = next(self.parameters()).device
1606
+
1607
+ self.eval()
1608
+
1609
+ # Initialize caches
1610
+ max_seqlen = max_new_tokens + 2048 + self.config.num_memory_tokens # Add buffer for input
1611
+ past_key_values, fla_past_key_values, mamba_inference_params = self.get_init_cache(
1612
+ max_seqlen=max_seqlen, batch_size=batch_size
1613
+ )
1614
+
1615
+ # Initialize KV caches for all modules
1616
+ for module in self.modules():
1617
+ if hasattr(module, 'init_kv_cache'):
1618
+ module.init_kv_cache(max_batch_size=batch_size, max_seq_len=max_seqlen)
1619
+
1620
+ with torch.no_grad():
1621
+ # Warmup runs
1622
+ dummy_input = torch.ones((batch_size, 10), dtype=torch.long, device=device)
1623
+ for _ in range(10):
1624
+ self(dummy_input)
1625
+
1626
+ # Prepare static tensors for CUDA graph
1627
+ static_current_input = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
1628
+ static_position_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
1629
+ static_logits = torch.zeros((batch_size, self.config.vocab_size), device=device)
1630
+
1631
+ # Set up for graph capture
1632
+ self.model.has_previous_state = True
1633
+ if mamba_inference_params is not None:
1634
+ mamba_inference_params.seqlen_offset = 1
1635
+
1636
+ # Warmup runs for graph capture
1637
+ for _ in range(10):
1638
+ model_kwargs_warmup = {
1639
+ 'input_ids': static_current_input,
1640
+ 'fla_past_key_values': fla_past_key_values,
1641
+ 'mamba_inference_params': mamba_inference_params,
1642
+ 'past_key_values': past_key_values,
1643
+ 'use_cache': True,
1644
+ 'position_ids': static_position_ids,
1645
+ }
1646
+ warmup_outputs = self(**model_kwargs_warmup)
1647
+
1648
+ # Capture CUDA graph
1649
+ generation_graph = CUDAGraph()
1650
+ with torch.cuda.graph(generation_graph):
1651
+ model_kwargs_graph = {
1652
+ 'input_ids': static_current_input,
1653
+ 'fla_past_key_values': fla_past_key_values,
1654
+ 'mamba_inference_params': mamba_inference_params,
1655
+ 'past_key_values': past_key_values,
1656
+ 'use_cache': True,
1657
+ 'position_ids': static_position_ids,
1658
+ }
1659
+ graph_outputs = self(**model_kwargs_graph)
1660
+ static_logits.copy_(graph_outputs.logits[:, -1, :])
1661
+
1662
+ if fla_past_key_values is not None:
1663
+ fla_past_key_values.reset()
1664
+
1665
+ if mamba_inference_params is not None:
1666
+ mamba_inference_params.reset(mamba_inference_params.max_seqlen, mamba_inference_params.max_batch_size)
1667
+ for key in mamba_inference_params.key_value_memory_dict:
1668
+ conv_state, ssm_state = mamba_inference_params.key_value_memory_dict[key]
1669
+ conv_state.zero_()
1670
+ ssm_state.zero_()
1671
+
1672
+ for module in self.modules():
1673
+ if hasattr(module, 'reset_kv_cache'):
1674
+ module.reset_kv_cache()
1675
+
1676
+ self.model.has_previous_state = False
1677
+
1678
+ # Return generation state
1679
+ generation_state = {
1680
+ 'generation_graph': generation_graph,
1681
+ 'static_current_input': static_current_input,
1682
+ 'static_position_ids': static_position_ids,
1683
+ 'static_logits': static_logits,
1684
+ 'past_key_values': past_key_values,
1685
+ 'fla_past_key_values': fla_past_key_values,
1686
+ 'mamba_inference_params': mamba_inference_params,
1687
+ 'max_seqlen': max_seqlen,
1688
+ 'batch_size': batch_size,
1689
+ 'device': device,
1690
+ }
1691
+
1692
+ return generation_state
1693
+
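# Note on the capture/replay pattern above (descriptive only): CUDA graphs require the captured
# call to read from and write to the *same* tensors on every replay, which is why
# `static_current_input`, `static_position_ids` and `static_logits` are allocated once, the
# decode step is captured inside `torch.cuda.graph(generation_graph)`, and later steps only copy
# new values into the static inputs and call `generation_graph.replay()`.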
1694
+ def generate_with_cuda_graph(
1695
+ self,
1696
+ input_ids,
1697
+ generation_state,
1698
+ max_new_tokens=128,
1699
+ temperature=1.0,
1700
+ top_k=0,
1701
+ top_p=0.9,
1702
+ eos_token_id=None,
1703
+ verbose=False,
1704
+ profiling=False,
1705
+ ):
1706
+ """
1707
+ Generate text using pre-initialized CUDA graph state.
1708
+
1709
+ Args:
1710
+ input_ids: Input token IDs tensor of shape (batch_size, seq_len)
1711
+ generation_state: State dictionary returned by init_cuda_graph_generation
1712
+ max_new_tokens: Maximum number of new tokens to generate
1713
+ temperature: Sampling temperature (0 for greedy)
1714
+ top_k: Top-k filtering (0 to disable)
1715
+ top_p: Top-p filtering (1.0 to disable)
1716
+ eos_token_id: End-of-sequence token ID
1717
+ verbose: Whether to print generated tokens
1719
+ profiling: Whether to return timing information
1720
+
1721
+ Returns:
1722
+ generated_ids: Tensor of shape (batch_size, input_len + generated_len)
1723
+ or (generated_ids, decode_latency) if profiling=True
1724
+ """
1725
+ self.eval()
1726
+ batch_size = input_ids.shape[0]
1727
+ device = input_ids.device
1728
+
1729
+ # Extract state
1730
+ generation_graph = generation_state['generation_graph']
1731
+ static_current_input = generation_state['static_current_input']
1732
+ static_position_ids = generation_state['static_position_ids']
1733
+ static_logits = generation_state['static_logits']
1734
+ past_key_values = generation_state['past_key_values']
1735
+ fla_past_key_values = generation_state['fla_past_key_values']
1736
+ mamba_inference_params = generation_state['mamba_inference_params']
1737
+
1738
+ with torch.no_grad():
1739
+ if mamba_inference_params.seqlen_offset == 0:
1740
+ if fla_past_key_values is not None:
1741
+ fla_past_key_values.reset()
1742
+
1743
+ if mamba_inference_params is not None:
1744
+ mamba_inference_params.reset(mamba_inference_params.max_seqlen, mamba_inference_params.max_batch_size)
1745
+ for key in mamba_inference_params.key_value_memory_dict:
1746
+ conv_state, ssm_state = mamba_inference_params.key_value_memory_dict[key]
1747
+ conv_state.zero_()
1748
+ ssm_state.zero_()
1749
+
1750
+ for module in self.modules():
1751
+ if hasattr(module, 'reset_kv_cache'):
1752
+ module.reset_kv_cache()
1753
+
1754
+ self.model.has_previous_state = False
1755
+
1756
+ # Prefill phase - process input sequence
1757
+ position_ids = torch.arange(
1758
+ self.config.num_memory_tokens + input_ids.shape[1], dtype=torch.long, device=device
1759
+ ).unsqueeze(0).expand(batch_size, -1)
1760
+
1761
+ else:
1762
+ # Prefill phase - process input sequence
1763
+ position_ids = torch.arange(
1764
+ mamba_inference_params.seqlen_offset, mamba_inference_params.seqlen_offset + input_ids.shape[1], dtype=torch.long, device=device
1765
+ ).unsqueeze(0).expand(batch_size, -1)
1766
+
1767
+ current_input = input_ids
1768
+
1769
+ model_kwargs = {
1770
+ 'input_ids': current_input,
1771
+ 'past_key_values': past_key_values,
1772
+ 'fla_past_key_values': fla_past_key_values,
1773
+ 'mamba_inference_params': mamba_inference_params,
1774
+ 'use_cache': True,
1775
+ 'position_ids': position_ids,
1776
+ }
1777
+
1778
+ if profiling:
1779
+ torch.cuda.synchronize()
1780
+ t1 = time.time()
1781
+
1782
+ # Forward pass for prefill
1783
+ outputs = self(**model_kwargs)
1784
+
1785
+ if mamba_inference_params is not None:
1786
+ if mamba_inference_params.seqlen_offset == 0:
1787
+ mamba_inference_params.seqlen_offset = current_input.shape[1] + self.config.num_memory_tokens
1788
+ else:
1789
+ mamba_inference_params.seqlen_offset += current_input.shape[1]
1790
+
1791
+ static_position_ids.fill_(position_ids[0, -1])
1792
+
1793
+ logits = outputs.logits[:, -1, :] # (batch_size, vocab_size)
1794
+ generated_tokens = []
1795
+
1796
+ # Generation loop using CUDA graph replay
1797
+ for step in range(max_new_tokens):
1798
+ # Sample next token using current logits
1799
+ if temperature == 0:
1800
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
1801
+ else:
1802
+ next_token = sample_token(logits, temperature=temperature, top_k=top_k, top_p=top_p)
1803
+
1804
+ generated_tokens.append(next_token)
1805
+
1806
+ # Check for EOS
1807
+ if not profiling and eos_token_id is not None and (next_token == eos_token_id).all():
1808
+ if verbose:
1809
+ print("\nEOS reached")
1810
+ break
1811
+
1812
+ # Update static tensors for graph replay
1813
+ static_current_input.copy_(next_token)
1814
+ static_position_ids.add_(1)
1815
+
1816
+ # Replay the captured graph
1817
+ generation_graph.replay()
1818
+
1819
+ if mamba_inference_params is not None:
1820
+ mamba_inference_params.seqlen_offset += 1
1821
+
1822
+ logits = static_logits.clone()
1823
+
1824
+ generated_ids = torch.cat([input_ids] + generated_tokens, dim=1)
1825
+
1826
+ if profiling:
1827
+ torch.cuda.synchronize()
1828
+ t2 = time.time()
1829
+ decode_latency = t2 - t1
1830
+ return generated_ids, decode_latency
1831
+
1832
+ return generated_ids
1833
+
1834
+
1835
+ def generate_with_cache(
1836
+ self,
1837
+ input_ids,
1838
+ max_new_tokens=128,
1839
+ temperature=1.0,
1840
+ top_k=0,
1841
+ top_p=0.9,
1842
+ eos_token_id=None,
1843
+ verbose=False,
1844
+ ):
1845
+ """
1846
+ Generate text using the hybrid model with proper cache handling (standard cached decoding, without CUDA graph capture).
1847
+
1848
+ Args:
1849
+ input_ids: Input token IDs tensor of shape (batch_size, seq_len)
1850
+ max_new_tokens: Maximum number of new tokens to generate
1851
+ temperature: Sampling temperature (0 for greedy)
1852
+ top_k: Top-k filtering (0 to disable)
1853
+ top_p: Top-p filtering (1.0 to disable)
1854
+ eos_token_id: End-of-sequence token ID
1855
+ verbose: Whether to print generated tokens
1856
+
1857
+ Returns:
1858
+ generated_ids: Tensor of shape (batch_size, input_len + generated_len)
1859
+ """
1860
+ self.eval()
1861
+ batch_size = input_ids.shape[0]
1862
+ device = input_ids.device
1863
+
1864
+ with torch.no_grad():
1865
+ max_seqlen = input_ids.shape[1] + max_new_tokens + self.config.num_memory_tokens
1866
+ past_key_values, fla_past_key_values, mamba_inference_params = self.get_init_cache(max_seqlen=max_seqlen, batch_size=batch_size)
1867
+
1868
+ for module in self.model.modules():
1869
+ if hasattr(module, 'init_kv_cache'):
1870
+ module.init_kv_cache(max_batch_size=batch_size, max_seq_len=max_seqlen)
1871
+
1872
+ # Prefill phase - process input sequence
1873
+ current_input = input_ids
1874
+ position_ids = torch.arange(
1875
+ self.model.config.num_memory_tokens + current_input.shape[1], dtype=torch.long, device=device
1876
+ ).unsqueeze(0).expand(batch_size, -1)
1877
+
1878
+ model_kwargs = {
1879
+ 'input_ids': current_input,
1880
+ 'past_key_values': past_key_values,
1881
+ 'fla_past_key_values': fla_past_key_values,
1882
+ 'mamba_inference_params': mamba_inference_params,
1883
+ 'use_cache': True,
1884
+ 'position_ids': position_ids,
1885
+ }
1886
+
1887
+ outputs = self(**model_kwargs)
1888
+
1889
+ # past_key_values, fla_past_key_values, mamba_inference_params = outputs.past_key_values
1890
+ mamba_inference_params.seqlen_offset = current_input.shape[1] + self.model.config.num_memory_tokens
1891
+
1892
+ logits = outputs.logits[:, -1, :] # (batch_size, vocab_size)
1893
+
1894
+ generated_tokens = []
1895
+
1896
+ # Generation loop
1897
+ for step in range(max_new_tokens):
1898
+ # Sample next token
1899
+ if temperature == 0:
1900
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
1901
+ else:
1902
+ next_token = sample_token(logits, temperature=temperature, top_k=top_k, top_p=top_p)
1903
+
1904
+ generated_tokens.append(next_token)
1905
+
1906
+ # Check for EOS
1907
+ if eos_token_id is not None and (next_token == eos_token_id).all():
1908
+ if verbose:
1909
+ print("\nEOS reached")
1910
+ break
1911
+
1912
+ current_input = next_token # Shape: (batch_size, 1)
1913
+
1914
+ # Update position_ids for decoding
1915
+ if position_ids is not None:
1916
+ position_ids = torch.full(
1917
+ (batch_size, 1),
1918
+ position_ids[0, -1] + 1,
1919
+ dtype=torch.long,
1920
+ device=device
1921
+ )
1922
+
1923
+ # Forward pass for next token
1924
+ model_kwargs = {
1925
+ 'input_ids': current_input,
1926
+ 'fla_past_key_values': fla_past_key_values,
1927
+ 'mamba_inference_params': mamba_inference_params,
1928
+ 'past_key_values': past_key_values,
1929
+ 'use_cache': True,
1930
+ 'position_ids': position_ids,
1931
+ }
1932
+
1933
+ outputs = self(**model_kwargs)
1934
+
1935
+ mamba_inference_params.seqlen_offset += 1
1936
+
1937
+ logits = outputs.logits[:, -1, :]
1938
+
1939
+ generated_ids = torch.cat([input_ids] + generated_tokens, dim=1)
1940
+
1941
+ return generated_ids
1942
+
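# A hedged usage sketch for the cached generation path above. The checkpoint name, tokenizer and
# prompt are assumptions for illustration only, not taken from this repository:
#   tokenizer = AutoTokenizer.from_pretrained("nvidia/some-nemotron-flash-checkpoint")   # hypothetical name
#   model = AutoModelForCausalLM.from_pretrained("nvidia/some-nemotron-flash-checkpoint",
#                                                trust_remote_code=True).cuda().eval()
#   input_ids = tokenizer("Hello", return_tensors="pt").input_ids.cuda()
#   out = model.generate_with_cache(input_ids, max_new_tokens=32, temperature=0)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))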
1943
+
1944
+ def prepare_inputs_for_generation(
1945
+ self,
1946
+ input_ids,
1947
+ past_key_values=None,
1948
+ attention_mask=None,
1949
+ inputs_embeds=None,
1950
+ **kwargs,
1951
+ ):
1952
+ if self.config.num_memory_tokens > 0:
1953
+ attention_mask = torch.cat([torch.ones(input_ids.shape[0], self.config.num_memory_tokens, device=attention_mask.device), attention_mask], dim=1)
1954
+
1955
+ ### Note that the KV cache is disabled when using model.generate; please use model.generate_with_cuda_graph or model.generate_with_cache instead.
1956
+ past_key_values = None
1957
+
1958
+ position_ids = kwargs.get("position_ids", None)
1959
+ if attention_mask is not None and position_ids is None:
1960
+ # create position_ids on the fly for batch generation
1961
+ position_ids = attention_mask.long().cumsum(-1) - 1
1962
+ position_ids.masked_fill_(attention_mask == 0, 1)
1963
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1964
+
1965
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1966
+ if inputs_embeds is not None:
1967
+ if input_ids.shape[1] == 0:
1968
+ model_inputs = {"inputs_embeds": inputs_embeds}
1969
+ else:
1970
+ inputs_embeds_new = self.model.embed_tokens(input_ids)
1971
+ model_inputs = {"inputs_embeds": torch.cat([inputs_embeds, inputs_embeds_new], dim=1)}
1972
+ else:
1973
+ model_inputs = {"input_ids": input_ids}
1974
+
1975
+ model_inputs.update(
1976
+ {
1977
+ "position_ids": position_ids,
1978
+ "past_key_values": past_key_values,
1979
+ "use_cache": kwargs.get("use_cache"),
1980
+ "attention_mask": attention_mask,
1981
+ }
1982
+ )
1983
+ return model_inputs
1984
+
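# A small sketch of the position-id construction in `prepare_inputs_for_generation` above:
# cumulative-sum the attention mask (so padded positions do not advance the counter), then
# clamp masked positions to a dummy value; the method keeps only the trailing window afterwards.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])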
1985
+
1986
+ def sample_token(logits, temperature=1.0, top_k=0, top_p=0.9):
1987
+ """
1988
+ Sample a token from logits with temperature, top-k, and top-p filtering.
1989
+
1990
+ Args:
1991
+ logits: Tensor of shape (batch_size, vocab_size)
1992
+ temperature: Sampling temperature
1993
+ top_k: Top-k filtering (0 to disable)
1994
+ top_p: Top-p filtering (1.0 to disable)
1995
+
1996
+ Returns:
1997
+ next_token: Tensor of shape (batch_size, 1)
1998
+ """
1999
+ if temperature == 0:
2000
+ return torch.argmax(logits, dim=-1, keepdim=True)
2001
+
2002
+ logits = logits / temperature
2003
+
2004
+ if top_k > 0:
2005
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
2006
+ logits.masked_fill_(indices_to_remove, float('-inf'))
2007
+
2008
+ if top_p < 1.0:
2009
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
2010
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
2011
+
2012
+ # Remove tokens with cumulative probability above the threshold
2013
+ sorted_indices_to_remove = cumulative_probs > top_p
2014
+ # Shift the indices to the right to keep also the first token above the threshold
2015
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
2016
+ sorted_indices_to_remove[..., 0] = 0
2017
+
2018
+ indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
2019
+ logits.masked_fill_(indices_to_remove, float('-inf'))
2020
+
2021
+ probs = F.softmax(logits, dim=-1)
2022
+ return torch.multinomial(probs, num_samples=1)
2023
+
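# A short usage sketch for `sample_token` above. It assumes the imports at the top of this file
# (torch and torch.nn.functional as F) are available; the dummy logits are illustrative only.
import torch

torch.manual_seed(0)
dummy_logits = torch.randn(2, 32)                    # (batch_size, vocab_size)
greedy = sample_token(dummy_logits, temperature=0)   # argmax decoding
sampled = sample_token(dummy_logits, temperature=0.8, top_k=10, top_p=0.9)
print(greedy.shape, sampled.shape)                   # torch.Size([2, 1]) torch.Size([2, 1])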