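"""Mamba sublayers for GPT-NeoX.

Implements a tensor-parallel Mamba block (ParallelMambaBlock), a pre-norm
residual wrapper around it (ParallelMambaResidualLayer), and a pipeline-compatible
variant (ParallelMambaResidualLayerPipe), built on the fused kernels from
https://github.com/state-spaces/mamba when available.
"""
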
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from mamba_ssm.ops.selective_scan_interface import (
        selective_scan_ref,
        selective_scan_fn,
        mamba_inner_fn,
    )
    from causal_conv1d import causal_conv1d_fn
    import einops
except ModuleNotFoundError:
    print(
        "Unable to import Mamba kernels. Install them from our requirements/requirements-mamba.txt, "
        "or directly from https://github.com/state-spaces/mamba"
    )

from megatron.model.norms import get_norm
from megatron import mpu

# Mamba sublayer, with tensor parallelism
class ParallelMambaBlock(nn.Module):
    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
    ):
        super().__init__()

        self.neox_args = neox_args

        dtype = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }[neox_args.precision]
        self.precision = dtype
        factory_kwargs = {"device": torch.cuda.current_device(), "dtype": dtype}

        assert not (
            neox_args.mamba_use_bias_in_linears and neox_args.mamba_inner_func_fusion
        ), "Mamba fused inner fn and bias in x_proj not compatible!"

        assert (
            neox_args.intermediate_size == None or neox_args.expansion_factor == None
        ), "Must pass either the absolute intermediate size or the relative expansion factor for the mamba projections"

        # set variables, mostly following mamba defaults
        self.d_model = neox_args.hidden_size
        self.d_state = 16  # state dimensions per channel
        self.d_conv = 4  # convolution width
        if neox_args.intermediate_size:
            self.d_inner = neox_args.intermediate_size
        else:
            self.expand = (
                neox_args.expansion_factor if neox_args.expansion_factor else 2
            )
            self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16)  # rank of dt / Delta parameter
        self.dt_scale = 1.0

        self.dt_init = "random"
        self.dt_min, self.dt_max, self.dt_init_floor = 0.001, 0.1, 1e-4
        assert self.dt_init in ["constant", "random"]

        # TP-specific setup
        world_size = mpu.get_model_parallel_world_size()
        self.d_inner_per_rank = mpu.divide(self.d_inner, world_size)

        if neox_args.mamba_inner_func_fusion and world_size > 1:
            # as with the gpt-j residual, we must manually reduce the output of the final proj
            # across TP ranks, since the fused mamba_inner_fn does not do this.
            self.reduce = mpu.mappings.reduce_from_model_parallel_region

        # up-projection.
        self.in_proj = mpu.ColumnParallelLinear(
            neox_args=neox_args,
            input_size=self.d_model,
            output_size=self.d_inner * 2,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
            bias=neox_args.mamba_use_bias_in_linears,
        )

        # convolution (parallelized across d_inner)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner_per_rank,
            out_channels=self.d_inner_per_rank,
            bias=neox_args.mamba_use_bias_in_conv,
            kernel_size=self.d_conv,
            groups=self.d_inner_per_rank,
            padding=self.d_conv - 1,
            **factory_kwargs,
        )
        # The conv bias sometimes erroneously ends up in fp32 when other parameters
        # are held in fp32 (cause unclear), so explicitly cast to the configured precision.
        self.conv1d.to(self.precision)

        self.act_fn = F.silu  # we do not allow for other activation fns

        # x_proj corresponds to s_B(x), s_C(x), s_Delta(x)
        # in https://arxiv.org/pdf/2312.00752.pdf Algorithm 2
        # (computes data-dependent B, C, Delta/dt)
        self.x_proj = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=self.d_inner,
            output_size=self.dt_rank + self.d_state * 2,
            input_is_parallel=True,
            init_method=init_method,
            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
            parallel_output=True,
            bias=neox_args.mamba_use_bias_in_linears,
        )
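        # x_proj output layout: [dt_rank | d_state | d_state], split in forward()
        # into the low-rank dt input and the data-dependent B and C.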

        # up-project dt / Delta from dt_rank to d_inner
        # dt_proj's bias is a special case and should always be kept on -- Alg. 2 in the Mamba paper (https://arxiv.org/abs/2312.00752)
        # defines Delta as Delta = Tau_{Delta}(Parameter + s_{Delta}(x)), where s_{Delta}(x) = Broadcast_{D}(Linear_{1}(x)),
        # or, as further explained in Section 3.6, s_{Delta}(x) = Linear_{D}(Linear_{R}(x)), where Linear_R
        # is the delta portion of x_proj and Linear_D is the dt_proj weight. The Parameter term from Alg. 2 can then
        # be viewed as the bias term in dt_proj, with a special initialization from https://arxiv.org/abs/2206.12037
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner_per_rank, bias=True, **factory_kwargs
        )
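        # Net effect (the softplus is applied inside the selective scan via delta_softplus=True):
        #   Delta = softplus(dt_proj.weight @ Linear_R(x) + dt_proj.bias)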

        # special init for dt_proj
        dt_init_std = (self.dt_rank**-0.5) * self.dt_scale
        if self.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif self.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # more dt_proj init stuff. copied from https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L91-L101
        dt = torch.exp(
            torch.rand(self.d_inner_per_rank, **factory_kwargs)
            * (math.log(self.dt_max) - math.log(self.dt_min))
            + math.log(self.dt_min)
        ).clamp(min=self.dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
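        # Since softplus(y) = log(1 + exp(y)), its inverse is log(exp(dt) - 1)
        # = dt + log(-expm1(-dt)). Once the kernel applies softplus to the bias,
        # the initial effective dt is sampled log-uniformly in [dt_min, dt_max]
        # (clamped at dt_init_floor).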

        # initialize A. Uses the S4D real initialization
        A = einops.repeat(
            torch.arange(
                1,
                self.d_state + 1,
                dtype=torch.float32,
                device=torch.cuda.current_device(),
            ),
            "n -> d n",
            d=self.d_inner_per_rank,
        ).contiguous()
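        # A[d, n] = n + 1 for each channel d; forward() uses A = -exp(A_log),
        # recovering the negative-real S4D diagonal (-1, -2, ..., -d_state).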
        A_log = torch.log(A).to(
            torch.float32
        )  # Keep in fp32, following https://github.com/state-spaces/mamba#precision and code comments
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = (
            True  # setting this attribute turns off weight decay for this param
        )
        # setting this attribute prevents DeeperSpeed from casting this param to fp32
        # requires DeeperSpeed commit https://github.com/EleutherAI/DeeperSpeed/commit/6d097beccc4e3b0ac806c7d975f8c10d4689de26 or later
        if self.neox_args.mamba_selective_fp32_params:
            self.A_log._deepspeed_no_cast = True

        # D parameter
        self.D = nn.Parameter(
            torch.ones(
                self.d_inner_per_rank,
                device=torch.cuda.current_device(),
                dtype=torch.float32,
            )
        ).to(
            torch.float32
        )  # Keep in fp32, following https://github.com/state-spaces/mamba#precision and code comments
        self.D._no_weight_decay = (
            True  # setting this attribute turns off weight decay for this param
        )
        # setting this attribute prevents DeeperSpeed from casting this param to fp32
        # requires DeeperSpeed commit https://github.com/EleutherAI/DeeperSpeed/commit/6d097beccc4e3b0ac806c7d975f8c10d4689de26 or later
        if self.neox_args.mamba_selective_fp32_params:
            self.D._deepspeed_no_cast = True

        # out down-projection.
        # use "single_residual_scaled_normal"
        # for output_layer_init_method
        # to perform gpt-2 style scaled init as done in Mamba paper.
        self.out_proj = mpu.RowParallelLinear(
            neox_args=neox_args,
            input_size=self.d_inner,
            output_size=self.d_model,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=not neox_args.mamba_use_bias_in_linears,
            bias=neox_args.mamba_use_bias_in_linears,
            parallel_output=False,
        )

    def selective_scan(
        self,
        x,
        dt,
        A,
        B,
        C,
        D,
        z=None,
        delta_bias=None,
        delta_softplus=True,
    ):
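        """Run the S6 selective scan, dispatching between the reference
        implementation and the fused CUDA kernel based on
        neox_args.mamba_selective_scan_fusion. Tensor shapes follow the
        upstream mamba_ssm convention (x, dt, z: [b, d, l]; B, C: [b, d_state, l])."""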

        if not self.neox_args.mamba_selective_scan_fusion:
            y = selective_scan_ref(
                u=x,
                delta=dt,
                A=A,
                B=B,
                C=C,
                D=D,
                z=z,
                delta_bias=delta_bias,
                delta_softplus=delta_softplus,
                return_last_state=False,
            )
        else:
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                D=D,
                z=z,
                delta_bias=delta_bias,
                delta_softplus=delta_softplus,
                return_last_state=False,
            )

        return y

    def forward(self, hidden_states):
        """ """
        # TODO: support inference natively in neox.
        # For now, we only handle training (parallel scan).
        assert self.training, "Mamba in NeoX does not support inference!"

        # hidden_states: [sq, b, h]
        seqlen, batch, dim = hidden_states.shape

        # first up: perform in_proj
        xz, _ = self.in_proj(hidden_states)
        xz = einops.rearrange(xz, "l b d -> b d l")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        if self.neox_args.mamba_inner_func_fusion:
            # =================
            # Fused mamba inner
            # =================

            # mamba provides a mamba_inner fn that computes the entire (post-in_proj) Mamba block.
            # we want to use it if we can, as it saves memory and provides speedups.
            # equivalent to use_fast_path=True in state-spaces/mamba.
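            # it fuses the causal conv, the x_proj / dt_proj projections,
            # the selective scan, and the out_proj matmul into a single call.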
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                # for some reason the conv bias becomes fp32 sometime after init when A and D are held in fp32;
                # cast it manually if the bias exists
                self.conv1d.bias.to(self.precision)
                if self.conv1d.bias is not None
                else self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # B is input-dependent, will compute from x_proj
                None,  # C is input-dependent, will compute from x_proj
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
            if getattr(self, "reduce", None):
                # manually reduce after mamba_inner_fn
                # to collect outputs from different TP ranks.
                # handled by running self.out_proj(y) below
                # so only needed here.
                out = self.reduce(out)

            out = einops.rearrange(out, "b l h -> l b h")

            return out

        x, z = xz.chunk(2, dim=1)

        # ===========
        # Convolution
        # ===========

        if not self.neox_args.mamba_causal_conv_fusion:
            self.conv1d.to(self.precision)  # required if keeping fp32 A_log, D
            x = self.act_fn(self.conv1d(x)[..., :seqlen])
        else:
            # Note: this requires silu as activation.
            x = causal_conv1d_fn(
                x=x,
                weight=einops.rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias.to(self.precision)
                if self.conv1d.bias is not None
                else self.conv1d.bias,
                activation="silu",
            )

        # ==============
        # SSM (S6) layer
        # ==============

        # project: perform s_B, s_C, s_Delta projections
        x_dbl, _ = self.x_proj(einops.rearrange(x, "b d l -> (b l) d"))
        # split into component dt, B, C
        dt, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )
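        # dt: (b*l, dt_rank); B, C: (b*l, d_state)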

        # up-project Delta / dt
        dt = self.dt_proj.weight @ dt.t()
        dt = einops.rearrange(dt, "d (b l) -> b d l", l=seqlen)
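        # note: dt_proj.bias is applied (together with softplus) inside
        # selective_scan via delta_bias / delta_softplus below.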

        # rearrange B, C
        B = einops.rearrange(B, "(b l) d_state -> b d_state l", l=seqlen).contiguous()
        C = einops.rearrange(C, "(b l) d_state -> b d_state l", l=seqlen).contiguous()

        # perform selective scan.
        y = self.selective_scan(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
        )

        # ===============
        # Down-Projection
        # ===============
        y = einops.rearrange(y, "b d l -> b l d")

        out, _ = self.out_proj(y)

        out = einops.rearrange(out, "b l h -> l b h")

        return out


class ParallelMambaResidualLayer(nn.Module):
    """
    Pre-norm Mamba Block with residual connection. No parallelism yet supported.
    """

    def __init__(
        self,
        neox_args,
        init_method,
        output_layer_init_method,
        layer_number,
    ):
        super().__init__()
        # TODO: allow for residual in fp32 if it helps?
        self.layer_number = layer_number

        # TODO: Add support for triton RMSNorm fused kernel at https://github.com/state-spaces/mamba/blob/v1.2.0/mamba_ssm/ops/triton/layernorm.py
        norm, eps = get_norm(neox_args)

        self.norm = norm(neox_args.hidden_size, eps=eps)

        self.mixer = ParallelMambaBlock(
            neox_args=neox_args,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
        )

    def forward(self, x, attention_mask=None, layer_past=None):

        # pseudocode:
        # x = x + mixer(norm(x))
        residual = x

        hidden_states = self.mixer(self.norm(x))

        return hidden_states + residual


class ParallelMambaResidualLayerPipe(ParallelMambaResidualLayer):
    """Extends MambaResidualLayer to forward attention_mask through the pipeline. DeepSpeed requires this."""

    def forward(self, args):
        assert (
            len(args) == 2
        ), "MambaResidualLayerPipe expects 2 arguments - hidden_states and attention_mask"
        hidden_states, attention_mask = args
        # return (hidden_states, attention_mask) so the mask keeps flowing through the pipeline
        return super().forward(hidden_states, attention_mask), attention_mask
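

# Hypothetical usage sketch (not executed; assumes megatron model-parallel groups
# are already initialized and `neox_args` is a fully populated NeoXArgs config):
#
#   layer = ParallelMambaResidualLayer(
#       neox_args=neox_args,
#       init_method=init_method,
#       output_layer_init_method=output_layer_init_method,
#       layer_number=0,
#   )
#   out = layer(hidden_states)  # hidden_states: [sq, b, h]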