File size: 18,574 Bytes
600a885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
from typing import Optional, Tuple, Union

import torch
from einops import rearrange
import torch.nn.functional as F

import triton
import triton.language as tl


# Triton kernel: rotate the first `rotary_dim` features of one tile of BLOCK_M sequence
# positions of X and write the result to OUT.
#   grid axis 0 -> tile index along the sequence (BLOCK_M positions per program)
#   grid axis 1 -> batch index
#   grid axis 2 -> head index
# Features at index >= rotary_dim are neither read nor written by this kernel.
# All arithmetic is done in fp32; tl.store casts the result back to OUT's element type.
@triton.jit
def rotary_kernel(
        OUT,  # output pointer (the same buffer as X when apply_rotary runs in place)
        X,  # input pointer
        COS,  # cos table; contiguous (batch, seqlen_ro, rotary_dim // 2) in the non-varlen path
        SIN,  # sin table, same layout as COS
        CU_SEQLENS,  # cumulative sequence lengths, only read when IS_VARLEN
        SEQLEN_OFFSETS,  # int, or pointer to one int per batch element when IS_SEQLEN_OFFSETS_TENSOR
        seqlen,
        nheads,
        rotary_dim,
        seqlen_ro,
        CACHE_KEY_SEQLEN,  # unused in the body; the caller passes seqlen // 128 as a Triton cache key
        # strides
        stride_out_batch,
        stride_out_nheads,
        stride_out_seqlen,
        stride_out_headdim,
        stride_x_batch,
        stride_x_nheads,
        stride_x_seqlen,
        stride_x_headdim,
        BLOCK_K: tl.constexpr,
        IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
        IS_VARLEN: tl.constexpr,
        INTERLEAVED: tl.constexpr,
        CONJUGATE: tl.constexpr,
        BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    if not IS_VARLEN:
        # Padded layout: move to this (batch, head); COS/SIN hold one table per batch element.
        X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
        OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
        COS = COS + pid_batch * seqlen_ro * rotary_dim_half
        SIN = SIN + pid_batch * seqlen_ro * rotary_dim_half
    else:
        # Packed layout: this sequence starts at row CU_SEQLENS[pid_batch] of X/OUT.
        # NOTE(review): COS/SIN are not advanced per batch here, so this path assumes a single
        # shared table; apply_rotary in this file always launches with IS_VARLEN=False.
        start_idx = tl.load(CU_SEQLENS + pid_batch)
        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
        X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
        OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads

    # Tiles that start past the end of the sequence have nothing to do.
    if pid_m * BLOCK_M >= seqlen:
        return
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Row index into the cos/sin table, shifted by the (per-batch) sequence offset.
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
    rk = tl.arange(0, BLOCK_K)
    rk_half = tl.arange(0, BLOCK_K // 2)

    if not INTERLEAVED:
        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
        X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])
        # Upcast to fp32 (as the interleaved branch does) so fp16/bf16 inputs do not lose
        # precision in the multiply-adds below; other=1.0/0.0 makes masked lanes an identity.
        cos = tl.load(
            COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0
        ).to(tl.float32)
        sin = tl.load(
            SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0
        ).to(tl.float32)
        x0 = tl.load(
            X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0
        ).to(tl.float32)
        x1 = tl.load(
            X + rotary_dim_half * stride_x_headdim,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
            other=0.0,
        ).to(tl.float32)
        if CONJUGATE:
            # Rotate by the negative angle (used for the backward pass).
            sin = -sin
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        # write back result (tl.store casts fp32 back to OUT's element type)
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)
        tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))
        tl.store(
            OUT + rotary_dim_half * stride_out_headdim,
            o1,
            mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),
        )
    else:
        # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
        # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
        # Loading x0 will be fast but x1 will be slow.
        # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
        # Then we do the calculation and use tl.where to pick out the right outputs for the even
        # and for the odd indices.
        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
        rk_repeat = tl.arange(0, BLOCK_K) // 2
        X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
        X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
        COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
        cos = tl.load(
            COS,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=1.0,
        ).to(tl.float32)
        sin = tl.load(
            SIN,
            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
            other=0.0,
        ).to(tl.float32)
        x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(
            tl.float32
        )
        x1 = tl.load(
            X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0
        ).to(tl.float32)
        if CONJUGATE:
            sin = -sin
        x0_cos = x0 * cos
        x1_sin = x1 * sin
        # Even features: x_even*cos - x_odd*sin; odd features: x_odd*cos + x_even*sin.
        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
        OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
        tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))


def apply_rotary(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        interleaved=False,
        inplace=False,
        conjugate=False,
) -> torch.Tensor:
    """
    Rotate the first ``rotary_dim`` features of ``x`` using the Triton rotary kernel.

    Arguments:
        x: (batch, nheads, seqlen, headdim). Only this padded layout is supported.
        cos: (batch, seqlen_ro, rotary_dim / 2), same dtype as ``x``. There is one table
            per batch element (FastRotaryEmbedding builds it with F.embedding(position_ids, ...)).
        sin: same shape and dtype as ``cos``.
        seqlen_offsets: integer, or integer tensor of size (batch,). Position ``i`` of batch
            element ``b`` uses row ``i + offset`` of ``cos[b]`` / ``sin[b]``.
        cu_seqlens: must be None. Variable-length (packed) input is not supported by this
            wrapper.
        max_seqlen: accepted for API compatibility; unused.
        interleaved: if True, rotate pairs of even/odd features (GPT-J style) instead of the
            1st and 2nd halves (GPT-NeoX style).
        inplace: if True, write the result into ``x`` and return ``x``.
        conjugate: if True, rotate by the negated angle (``-sin``); used by the backward pass.

    Returns:
        y: (batch, nheads, seqlen, headdim), same dtype as ``x``. Features beyond
        ``rotary_dim`` are copied unchanged from ``x``.

    Raises:
        NotImplementedError: if ``cu_seqlens`` is not None.
        AssertionError: on shape / dtype mismatches between ``x``, ``cos`` and ``sin``.
    """
    if cu_seqlens is not None:
        # The kernel is always launched with IS_VARLEN=False below, so cu_seqlens would be
        # silently ignored and produce wrong results; fail loudly instead.
        raise NotImplementedError(
            "apply_rotary: variable-length input (cu_seqlens) is not supported; "
            "pass x as (batch, nheads, seqlen, headdim)"
        )

    batch, nheads, seqlen, headdim = x.shape

    batch_ro, seqlen_ro, rotary_dim = cos.shape

    assert batch == batch_ro
    assert sin.shape == cos.shape
    rotary_dim *= 2  # cos/sin store one value per (even, odd) or (1st half, 2nd half) pair
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"

    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

    assert (
            cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
            x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    # The kernel indexes cos/sin assuming a contiguous (batch, seqlen_ro, rotary_dim / 2) layout.
    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        # The kernel only writes the rotary part; carry the remaining features over.
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)

    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
            cos,
            sin,
            cu_seqlens,
            seqlen_offsets,
            seqlen,  # shapes
            nheads,
            rotary_dim,
            seqlen_ro,
            seqlen // 128,  # key for triton cache (limit number of compilations)
            output.stride(0),  # batch_strides
            output.stride(-3),  # nheads_stride
            output.stride(-2),  # seqlen_stride
            output.stride(-1),  # headdim_stride
            x.stride(0),  # batch_strides
            x.stride(-3),  # nheads stride
            x.stride(-2),  # seqlen stride
            x.stride(-1),  # headdim stride
            BLOCK_K,
            isinstance(seqlen_offsets, torch.Tensor),
            False,  # IS_VARLEN
            interleaved,
            conjugate,
            BLOCK_M,
        )
    return output


class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd wrapper around ``apply_rotary``.

    The gradient w.r.t. ``x`` is the inverse rotation of the incoming gradient, obtained by
    re-running ``apply_rotary`` with ``conjugate=True`` (``-sin``). ``cos``, ``sin`` and the
    remaining arguments receive no gradient.
    """

    @staticmethod
    def forward(
            ctx,
            x,
            cos,
            sin,
            interleaved=False,
            inplace=False,
            seqlen_offsets: Union[int, torch.Tensor] = 0,
            cu_seqlens: Optional[torch.Tensor] = None,
            max_seqlen: Optional[int] = None,
    ):
        rotated = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        # save_for_backward only accepts tensors (or None): a plain int offset is kept on
        # ctx instead, and ctx.seqlen_offsets = None flags that the offsets were saved.
        offsets_are_int = isinstance(seqlen_offsets, int)
        if offsets_are_int:
            tensors_to_save = (cos, sin, cu_seqlens)
        else:
            tensors_to_save = (cos, sin, cu_seqlens, seqlen_offsets)
        ctx.save_for_backward(*tensors_to_save)
        ctx.seqlen_offsets = seqlen_offsets if offsets_are_int else None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return x if inplace else rotated

    @staticmethod
    def backward(ctx, do):
        if ctx.seqlen_offsets is not None:
            cos, sin, cu_seqlens = ctx.saved_tensors
            seqlen_offsets = ctx.seqlen_offsets
        else:
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not (ctx.interleaved or ctx.inplace):
            do = do.clone()
        grad_x = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        # forward() takes 8 inputs after ctx; only x is differentiable.
        return (grad_x,) + (None,) * 7


def apply_rotary_emb(
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
):
    """
    Apply rotary embedding to the first rotary_dim features of x (autograd-aware).

    Arguments:
        x: (batch_size, nheads, seqlen, headdim)
        cos, sin: (batch_size, seqlen_rotary, rotary_dim / 2), same dtype as x.
            rotary_dim must be <= headdim and seqlen_rotary >= seqlen.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place and return x.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: variable-length (packed) input is not supported by the kernel wrapper;
            pass None.
        max_seqlen: accepted for API compatibility; not used by the kernel wrapper.

    Return:
        out: (batch_size, nheads, seqlen, headdim); features beyond rotary_dim are the
            same as in x.
    """
    # autograd.Function.apply takes its arguments positionally, in forward()'s order.
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )


# For backward compatibility: `apply_rotary_emb_func` is an alias (the same function object)
# of `apply_rotary_emb`; FastRotaryEmbedding.forward in this file calls it by this name.
apply_rotary_emb_func = apply_rotary_emb


class FastRotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py

    ``forward`` rotates q and k *in place* with the Triton kernel, looking up one cos/sin
    row per token from ``position_ids``.
    """

    def __init__(
            self,
            dim: int,
            base=10000,
            interleaved=False,
            scale_base=None,
            pos_idx_in_fp32=True,
            device=None,
    ):
        """
        dim: rotary dimension; the cos/sin tables have one column per pair of features.
        base: base of the geometric frequency progression (inv_freq = 1 / base ** (2i / dim)).
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: if not None, enables XPos scaling with this base.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        device: device for the inv_freq / scale buffers.
        """
        super().__init__()
        self.dim = dim
        self.base = base
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.cos = None
        self.sin = None

    def _compute_inv_freq(self, device=None):
        """Return ``1 / base ** (2i / dim)`` for every even feature index ``2i < dim``.

        Dividing the integer arange by an int promotes to the default float dtype.
        """
        exponents = torch.arange(0, self.dim, 2, device=device) / self.dim
        return 1.0 / (self.base ** exponents)

    def _cos_sin_cache_is_stale(self, seqlen, device=None, dtype=None):
        """Return True if the cos/sin tables must be (re)built for this request."""
        cached = self._cos_cached
        if cached is None or seqlen > self._seq_len_cached:
            return True
        # Reusing tables of another device/dtype would break F.embedding (device mismatch)
        # or apply_rotary's "same dtype" assertion, e.g. after module.to(...) or when the
        # activations switch between fp32 and bf16.
        if device is not None and cached.device != torch.device(device):
            return True
        return dtype is not None and cached.dtype != dtype

    def _update_cos_sin_cache(self, seqlen, position_id, device=None, dtype=None):
        """(Re)build the cos/sin tables for positions ``0 .. seqlen - 1`` when needed.

        Tables are rebuilt when none exist, when ``seqlen`` exceeds the cached length, or
        when ``device`` / ``dtype`` differ from the cached tables. Each table has one row
        per position and one column per entry of ``inv_freq``. In XPos mode the
        key-side tables ``_cos_k_cached`` / ``_sin_k_cached`` are built as well.
        ``position_id`` is accepted for interface compatibility and is not used.
        """
        if not self._cos_sin_cache_is_stale(seqlen, device=device, dtype=dtype):
            return
        self._seq_len_cached = seqlen
        # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
        # And the output of arange can be quite large, so bf16 would lose a lot of precision.
        # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
        if self.pos_idx_in_fp32:
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            # We want fp32 here as well since inv_freq will be multiplied with t, and the output
            # will be large. Having it in bf16 will lose a lot of precision and cause the
            # cos & sin output to change significantly.
            # We want to recompute self.inv_freq if it was not loaded in fp32
            if self.inv_freq.dtype != torch.float32:
                inv_freq = self._compute_inv_freq(device=device)
            else:
                inv_freq = self.inv_freq
        else:
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            inv_freq = self.inv_freq
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)
        else:
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                - seqlen // 2
            ) / self.scale_base
            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            position_ids: torch.Tensor,
            max_seqlen,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, nheads, seqlen, headdim)
        k: (batch, nheads, seqlen, headdim)
        position_ids: (batch, seqlen) integer positions, expected to be < max_seqlen.
        max_seqlen: int, number of positions the cos/sin tables must cover.

        Apply rotary embedding *inplace* to q and k and return them.
        """
        self._update_cos_sin_cache(max_seqlen, position_ids, device=q.device, dtype=q.dtype)
        # Gather one cos/sin row per token: (batch, seqlen, dim / 2).
        cos = F.embedding(position_ids, self._cos_cached)
        sin = F.embedding(position_ids, self._sin_cached)
        if self.scale is None:
            cos_k, sin_k = cos, sin
        else:
            # XPos: queries use the tables multiplied by `scale`, keys the ones divided by it.
            cos_k = F.embedding(position_ids, self._cos_k_cached)
            sin_k = F.embedding(position_ids, self._sin_k_cached)

        q = apply_rotary_emb_func(
            q,
            cos,
            sin,
            interleaved=self.interleaved,
            inplace=True
        )
        k = apply_rotary_emb_func(
            k,
            cos_k,
            sin_k,
            interleaved=self.interleaved,
            inplace=True
        )
        return q, k