import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor
from typing import Optional, Tuple
import math

import logging
import copy

from dearth_config import DearthConfig

_USE_FAST_ROPE = False

class RMSNorm(torch.nn.Module): # a variant of LayerNorm that is faster and more memory efficient
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # set the weight to be 1 initially
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
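
# RMSNorm in one line (illustrative): y = x / sqrt(mean(x^2, dim=-1) + eps) * weight,
# i.e. LayerNorm without mean-centering and without a bias term, which saves one
# reduction pass over the activations.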

# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

        self.register_buffer("default_pos_ids", 
                             torch.arange(0, self.max_position_embeddings, dtype=torch.long).view(-1, self.max_position_embeddings), 
                             persistent=False)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) # shape: (max_seq_len_cached, dim // 2)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
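
# Usage sketch (illustrative, shapes assumed from the comment above): for
#   x: (bs, n_head, seq_len, head_dim), rope = RotaryEmbedding(head_dim)
#   cos, sin = rope(x, seq_len=seq_len)   # each of shape (seq_len, head_dim)
# The cos/sin cache is rebuilt lazily whenever seq_len exceeds max_seq_len_cached.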

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
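
# Shape walk-through (illustrative): with q, k of shape (bs, n_head, seq_len, head_dim),
# cos/sin of shape (seq_len, head_dim) and position_ids of shape (1, seq_len),
# cos[position_ids].unsqueeze(1) becomes (1, 1, seq_len, head_dim), so one rotation
# per position is broadcast across all batch elements and heads.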




class FastRope(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        cis = precompute_freqs_cis(dim, max_position_embeddings, theta=base)
        self.register_buffer("cis", cis, persistent=False)

    def forward(self, start_idx, seq_len):
        return self.cis[start_idx:start_idx+seq_len, :]

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    with torch.no_grad():
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)  # type: ignore
        freqs = torch.outer(t, freqs).float()  # type: ignore
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
        return freqs_cis
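
# Worked example (illustrative): precompute_freqs_cis(dim=4, end=2) gives
# freqs = [1/10000^(0/4), 1/10000^(2/4)] = [1.0, 0.01], t = [0, 1], and returns
# the (2, 2) complex64 tensor [[e^(0j), e^(0j)], [e^(1j), e^(0.01j)]]:
# one unit-magnitude rotation per (position, frequency) pair.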

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings.
        xk (torch.Tensor): Key tensor to apply rotary embeddings.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
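
# Shape walk-through (illustrative): xq of shape (bs, seq_len, n_head, head_dim) is
# viewed as complex pairs (bs, seq_len, n_head, head_dim // 2); freqs_cis of shape
# (seq_len, head_dim // 2) is reshaped to (1, seq_len, 1, head_dim // 2); the complex
# product rotates each pair, and view_as_real + flatten(3) restores the real layout.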

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f"freqs_cis.shape: {freqs_cis.shape}, x.shape: {x.shape}"
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)




class AttentionMask(nn.Module):
    attn_mask: torch.Tensor = None
    def __init__(self, config: DearthConfig):
        super().__init__()
        self.config = config
        self.sliding_window_size = config.sliding_window_size
        self.front_window_size = config.front_window_size
        if self.attn_mask is None: # built once per instance; kept as a plain attribute so it stays out of the state_dict
            tmp_attn_mask = self.build_causal_and_window_mask(config.max_token_len, config.sliding_window_size, config.front_window_size)
            self.attn_mask = tmp_attn_mask.requires_grad_(False) # shape: (max_token_len, max_token_len)

    def forward(self, bz, n_head, q_seq_len, kv_seq_len, q_start_idx: int, device, dtype) -> torch.Tensor:
        if self.attn_mask.device != device or self.attn_mask.dtype != dtype:
            self.attn_mask = self.attn_mask.to(device=device, dtype=dtype).requires_grad_(False)
        end_idx = q_start_idx + q_seq_len
        q_k_diff_len = kv_seq_len - q_seq_len # should be >= 0; queries cannot attend to future tokens
        top = q_start_idx
        bottom = end_idx
        if q_start_idx == 0 and q_k_diff_len == 0:
            # assume: sliding window size = 100, front window size = 50
            # case 1: training: q_start_idx = 0, q_seq_len = 1000, kv_seq_len = 1000
            mask = self.attn_mask[:end_idx, :end_idx]
        elif q_k_diff_len > 0 and q_start_idx > 0 and end_idx >= kv_seq_len:
            # TODO: not allow in training; remove this line after testing
            raise RuntimeError(f"NOT FOR TRAINING: q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
            if end_idx > self.front_window_size + self.sliding_window_size:
                # case 2: qsl < kvsl: q_start_idx = 190, q_seq_len = 10, kv_seq_len = 150, end_idx = 200
                # mask = self.attn_mask[top:bottom, :self.front_window_size] + \
                #     self.attn_mask[q_start_idx:end_idx, end_idx - (kv_seq_len - self.front_window_size):end_idx]
                mask = torch.cat([self.attn_mask[top:bottom, :self.front_window_size], self.attn_mask[top:bottom, end_idx - (kv_seq_len - self.front_window_size):end_idx]], dim=-1)
            elif end_idx <= self.front_window_size + self.sliding_window_size:
                # case 3: qsl < kvsl: q_start_idx = 140, q_seq_len = 10, kv_seq_len = 150, end_idx = 150
                mask = self.attn_mask[top:bottom, :end_idx]
        else:
            raise RuntimeError(f"q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
        return mask.expand(bz, n_head, q_seq_len, kv_seq_len).detach()
        
    
    @staticmethod
    def build_causal_and_window_mask(seq_len, sliding_window_size, front_window_size) -> torch.Tensor:
        mask = torch.ones(seq_len, seq_len)
        if seq_len > sliding_window_size: # need to apply a sliding window mask because the sequence is too long
            mask = torch.triu(mask, diagonal=-sliding_window_size+1)
            if front_window_size > 0:
                tmp_front_mask = torch.cat([torch.ones(seq_len, front_window_size), torch.zeros(seq_len, seq_len-front_window_size)], dim=-1)
                tmp_front_mask = torch.tril(tmp_front_mask, diagonal=-sliding_window_size)
                mask = mask + tmp_front_mask
        # apply causal mask
        mask = mask.tril(diagonal=0)
        mask = mask.log() # map 0 to -inf, 1 to 0
        return mask
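
# Worked example (illustrative): build_causal_and_window_mask(6, sliding_window_size=3,
# front_window_size=2) lets each query see the last 3 positions (causal sliding window)
# plus the first 2 positions (front window). Before .log() the 0/1 mask is:
#   row 0: 1 0 0 0 0 0
#   row 1: 1 1 0 0 0 0
#   row 2: 1 1 1 0 0 0
#   row 3: 1 1 1 1 0 0
#   row 4: 1 1 1 1 1 0
#   row 5: 1 1 0 1 1 1   <- front window {0, 1} + sliding window {3, 4, 5}
# .log() then maps 1 -> 0 and 0 -> -inf so the mask can be added to attention logits.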
    

class SharedAttentionMask(nn.Module):
    def __init__(self, config: DearthConfig):
        super().__init__()
        self.config = config
        self.sliding_window_size = config.sliding_window_size
        self.front_window_size = config.front_window_size
        tmp_attn_mask = self.build_causal_and_window_mask(config.max_token_len, config.sliding_window_size, config.front_window_size)
        self.register_buffer("attn_mask", tmp_attn_mask, persistent=False)

    def forward(self, q_seq_len, kv_seq_len, q_start_idx: int) -> torch.Tensor:
        end_idx = q_start_idx + q_seq_len
        q_k_diff_len = kv_seq_len - q_seq_len # should be >= 0; queries cannot attend to future tokens
        top = q_start_idx
        bottom = end_idx
        if q_start_idx == 0 and q_k_diff_len == 0:
            # assume: sliding window size = 100, front window size = 50
            # case 1: training: q_start_idx = 0, q_seq_len = 1000, kv_seq_len = 1000
            mask = self.attn_mask[:end_idx, :end_idx]
        elif q_k_diff_len > 0 and q_start_idx > 0 and end_idx >= kv_seq_len:
            # TODO: not allow in training; remove this line after testing
            raise RuntimeError(f"NOT FOR TRAINING: q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
            if end_idx > self.front_window_size + self.sliding_window_size:
                # case 2: qsl < kvsl: q_start_idx = 190, q_seq_len = 10, kv_seq_len = 150, end_idx = 200
                # mask = self.attn_mask[top:bottom, :self.front_window_size] + \
                #     self.attn_mask[q_start_idx:end_idx, end_idx - (kv_seq_len - self.front_window_size):end_idx]
                mask = torch.cat([self.attn_mask[top:bottom, :self.front_window_size], self.attn_mask[top:bottom, end_idx - (kv_seq_len - self.front_window_size):end_idx]], dim=-1)
            elif end_idx <= self.front_window_size + self.sliding_window_size:
                # case 3: qsl < kvsl: q_start_idx = 140, q_seq_len = 10, kv_seq_len = 150, end_idx = 150
                mask = self.attn_mask[top:bottom, :end_idx]
        else:
            raise RuntimeError(f"q_start_idx = {q_start_idx}, q_seq_len = {q_seq_len}, kv_seq_len = {kv_seq_len}")
        return mask.detach() # shape: (q_seq_len, kv_seq_len); broadcasts over batch and heads
        
    
    @staticmethod
    def build_causal_and_window_mask(seq_len, sliding_window_size, front_window_size) -> torch.Tensor:
        mask = torch.ones(seq_len, seq_len)
        if seq_len > sliding_window_size: # need to apply a sliding window mask because the sequence is too long
            mask = torch.triu(mask, diagonal=-sliding_window_size+1)
            if front_window_size > 0:
                tmp_front_mask = torch.cat([torch.ones(seq_len, front_window_size), torch.zeros(seq_len, seq_len-front_window_size)], dim=-1)
                tmp_front_mask = torch.tril(tmp_front_mask, diagonal=-sliding_window_size)
                mask = mask + tmp_front_mask
        # apply causal mask
        mask = mask.tril(diagonal=0)
        mask = mask.log() # map 0 to -inf, 1 to 0
        return mask



def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
    r"""
    Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292

    Return shape: (1, num_heads, 1, sequence_length)
    """
    alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
    num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))

    base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device)
    base = base * (alibi_bias_max / num_heads_power_of_2)

    slopes = 1.0 / torch.pow(2, base)
    slopes = slopes.view(1, num_heads_power_of_2, 1, 1)

    if num_heads_power_of_2 != num_heads:
        slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[:, :num_heads, ...]

    alibi = alibi * slopes
    return alibi



def compute_alibi(num_heads, sequence_length, alibi_bias_max=8, device=None):
    r"""
    Link to paper: https://arxiv.org/abs/2108.12409 - The alibi tensor is not causal as the original paper mentions; it
    relies on the translation invariance of softmax for a quick implementation. This is a simplified variant that
    gives head i (1-indexed) a geometrically decreasing slope 2^(-alibi_bias_max * i / num_heads); see
    build_mpt_alibi_tensor above for the MPT-style version.

    Return shape: (1, num_heads, 1, sequence_length)
    """
    slope = []
    m_power = -alibi_bias_max / num_heads
    m_increase = -alibi_bias_max / num_heads
    for i in range(num_heads):
        slope.append(2 ** m_power)
        m_power += m_increase
    slope = torch.tensor(slope, device=device)
    alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
    alibi = alibi * slope.view(1, num_heads, 1, 1)
    return alibi
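
# Worked example (illustrative): compute_alibi(num_heads=4, sequence_length=3) uses
# slopes [2^-2, 2^-4, 2^-6, 2^-8] and distances [-2, -1, 0], so head 0 of the
# returned (1, 4, 1, 3) tensor is [-0.5, -0.25, 0.0].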


class Attention(nn.Module):
    def __init__(self, config: DearthConfig):
        super().__init__()
        assert config.dim % config.n_head == 0

        # regularization
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
        self.dim = config.dim
        self.dim_qk_head = config.dim_qk_head if config.dim_qk_head is not None else config.dim // config.n_head
        self.dim_v_head = config.dim // config.n_head
        assert config.n_kv_head <= config.n_head and config.n_head % config.n_kv_head == 0
        self.n_kv_group = config.n_head // config.n_kv_head
        self.dropout_rate = config.dropout_rate

        self.alibi_emb = None
        self.pos_emb = None

        self.sliding_window_size = config.sliding_window_size

        self.window_size_mask = AttentionMask(config)

        if config.use_rotary:
            if not _USE_FAST_ROPE:
                self.pos_emb = RotaryEmbedding(
                    self.dim_qk_head,
                    max_position_embeddings=config.max_token_len,
                    base=config.rope_theta,
                )
            else:
                self.pos_emb = FastRope(
                    self.dim_qk_head,
                    max_position_embeddings=config.max_token_len,
                    base=config.rope_theta,
                )

        # query, key, values projections for all heads
        self.wq = nn.Linear(self.dim, self.n_head * self.dim_qk_head, bias=True)
        self.wk = nn.Linear(self.dim, self.n_kv_head * self.dim_qk_head, bias=True)
        self.wv = nn.Linear(self.dim, self.dim // self.n_kv_group, bias=False)
        self.wo = nn.Linear(self.dim, self.dim, bias=False)

        
    def forward(self, x: Tensor, attn_mask: Tensor, start_idx: int = 0):
        batch_size, seqlen, emb_dim = x.size() # batch size, sequence length, embedding dimensionality (dim)
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # split embedding dim into number of heads
        xq = xq.view(batch_size, seqlen, self.n_head, self.dim_qk_head)
        xk = xk.view(batch_size, seqlen, self.n_kv_head, self.dim_qk_head)
        xv = xv.view(batch_size, seqlen, self.n_kv_head, self.dim_v_head)

        if self.pos_emb is not None and _USE_FAST_ROPE:
            xq, xk = apply_rotary_emb(xq, xk, self.pos_emb(start_idx, seqlen))

        # transpose to get dimensions batch_size * n_head * seqlen * head_dim
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
        kv_seqlen = xk.size(2)

        # apply positional embeddings
        if self.pos_emb is not None and not _USE_FAST_ROPE:
            # self.pos_emb = self.pos_emb.to(x.device, dtype=x.dtype)
            # xq, xk = apply_rotary_pos_emb(xq, xk, self.pos_emb[start_idx:start_idx+seqlen])
            cos, sin = self.pos_emb(xv, seq_len=kv_seqlen)
            xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin, self.pos_emb.default_pos_ids[:, :kv_seqlen])

        # TODO: add cache for fast inference


        # grouped query
        xk = repeat_kv(xk, self.n_kv_group)
        xv = repeat_kv(xv, self.n_kv_group)

        # efficient attention using Flash Attention CUDA kernels
        y = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout_rate if self.training else 0)
        y = y.transpose(1, 2).contiguous().view(batch_size, seqlen, emb_dim) # merge heads

        # output projection
        return self.wo(y)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    hidden_states.shape = (batch, n_kv_head, seqlen, head_dim)
    """
    # if n_rep == 1:
    #     return hidden_states
    # return torch.repeat_interleave(hidden_states, n_rep, dim=1)

    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
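
# Shape example (illustrative): repeat_kv on a (2, 4, 10, 64) tensor with n_rep=2
# returns (2, 8, 10, 64) with each kv head duplicated in adjacent slots
# ([h0, h0, h1, h1, ...]); the expand is a view, so nothing is copied until the
# final reshape.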


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        dim = config.dim
        hidden_dim = config.dim * 4 if config.hidden_dim is None else config.hidden_dim
        multiple_of = 64 if config.multiple_of is None else config.multiple_of
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) # round up to nearest multiple of multiple_of

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
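
# Worked example (illustrative): with dim=512, multiple_of=64 and hidden_dim unset,
# the SwiGLU width is int(2 * (4 * 512) / 3) = 1365, rounded up to a multiple of 64
# -> 1408; so w1/w3 map 512 -> 1408 and w2 maps 1408 -> 512.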

class Mimic_Attn(Attention):
    def __init__(self, config):
        new_config = copy.deepcopy(config)
        new_config.n_head = config.mimic_n_head if config.mimic_n_head is not None else config.n_head
        new_config.n_kv_head = config.mimic_n_kv_head if config.mimic_n_kv_head is not None else config.n_kv_head
        new_config.dim_qk_head = config.mimic_dim_qk_head if config.mimic_dim_qk_head is not None else config.dim_qk_head
        new_config.dropout_rate = config.mimic_attn_dropout if config.mimic_attn_dropout is not None else 0.0
        new_config.use_rotary = config.mimic_use_rotary if config.mimic_use_rotary is not None else config.use_rotary
        new_config.use_alibi = config.mimic_use_alibi if config.mimic_use_alibi is not None else config.use_alibi

        super().__init__(new_config)
        self.saved_q = None
        self.saved_k = None
        self.saved_v = None
        self.saved_attn_map = None

    def forward(self, x: Tensor, attn_mask: Tensor, start_idx: int = 0): # attn_mask broadcastable to (bz, n_head, q_seq_len, kv_seq_len)
        batch_size, seqlen, emb_dim = x.size() # batch size, sequence length, embedding dimensionality (dim)
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        self.saved_v = xv

        # split embedding dim into number of heads
        xq = xq.view(batch_size, seqlen, self.n_head, self.dim_qk_head)
        xk = xk.view(batch_size, seqlen, self.n_kv_head, self.dim_qk_head)
        xv = xv.view(batch_size, seqlen, self.n_kv_head, self.dim_v_head)

        if self.pos_emb is not None and _USE_FAST_ROPE:
            xq, xk = apply_rotary_emb(xq, xk, self.pos_emb(start_idx, seqlen))

        # transpose to get dimensions batch_size * n_head * seqlen * emb_dim
        xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
        kv_seqlen = xk.size(2)

        # apply positional embeddings
        if self.pos_emb is not None and not _USE_FAST_ROPE:
            cos, sin = self.pos_emb(xv, seq_len=kv_seqlen)
            xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin, self.pos_emb.default_pos_ids[:, :kv_seqlen])

        # TODO: add cache for fast inference

        # grouped query
        xk = repeat_kv(xk, self.n_kv_group)
        xv = repeat_kv(xv, self.n_kv_group)

        attn_weights = torch.matmul(xq, xk.transpose(2, 3)) * (1 / math.sqrt(self.dim_qk_head)) # shape: (batch_size, n_head, seqlen, kv_seqlen)
        attn_weights = attn_weights + attn_mask.expand(batch_size, self.n_head, seqlen, kv_seqlen) # shape: (batch_size, n_head, seqlen, kv_seqlen)
        attn_weights = F.softmax(attn_weights.float(), dim=-1).to(xq.dtype) # shape: (batch_size, n_head, seqlen, kv_seqlen)
        # use log_softmax to avoid overflow
        #attn_weights = F.log_softmax(attn_weights, dim=-1).exp() # shape: (batch_size, n_head, seqlen, seqlen)
        self.saved_attn_map = attn_weights

        attn_weights = F.dropout(attn_weights, p=self.dropout_rate, training=self.training)

        y = torch.matmul(attn_weights, xv) # shape: (batch_size, n_head, seqlen, head_dim)

        y = y.transpose(1, 2).contiguous().view(batch_size, seqlen, emb_dim) # merge heads

        # output projection
        return self.wo(y)

    def get_intermediate_attn_v(self):
        return self.saved_attn_map, self.saved_v
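
# Note (shapes inferred from forward above): saved_attn_map is the post-softmax map of
# shape (bs, n_head, q_seq_len, kv_seq_len) and saved_v is the value projection before
# head-splitting, of shape (bs, seq_len, dim // n_kv_group); presumably exposed so a
# mimic/distillation loss can match them against a teacher's attention and values.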


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.dim)
        self.attn = Attention(config)
        self.ln_2 = RMSNorm(config.dim)
        self.mlp = MLP(config)

        self.residual_factor = config.residual_factor

    def forward(self, x: Tensor, attn_mask: Tensor, start_idx: int):
        # post-LN
        residual = x
        x = self.attn(x, attn_mask, start_idx=start_idx)
        x = self.ln_1(self.residual_connection(x, residual))

        residual = x
        x = self.mlp(x)
        x = self.ln_2(self.residual_connection(x, residual))

        return x
    
    def residual_connection(self, x, residual):
        # the residual scaling factor should be > 1.0; see DearthModel.__init__ for the default
        return residual * self.residual_factor + x



class DearthModel(nn.Module):
    def __init__(self, config: DearthConfig):
        super().__init__()
        assert config.vocab_size is not None
        assert config.max_token_len is not None

        self.layer_init_factor = config.layer_init_factor if config.layer_init_factor is not None else float(config.n_layer * 8) ** (-1/2)
        self.residual_factor = config.residual_factor if config.residual_factor is not None else float(config.n_layer * 2) ** (1/4)
        if config.residual_factor is None:
            config.residual_factor = self.residual_factor
            #logging.warning(f"residual_factor is not set, using default value {self.residual_factor} = (2 * n_layer) ** 1/4")
        if config.layer_init_factor is None:
            config.layer_init_factor = self.layer_init_factor
            #logging.warning(f"layer_init_factor is not set, using default value {self.layer_init_factor} = (n_layer * 8) ** -1/2")
        
        self.config = config

        layers = []
        for i in range(config.n_layer):
            if config.mimic_attn_layer is not None and i+1 == config.mimic_attn_layer:
                new_layer = TransformerBlock(config)
                new_layer.attn = Mimic_Attn(config)
                layers.append(new_layer)
            else:
                layers.append(TransformerBlock(config))

        self.layers = nn.ModuleList(layers)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.ln_before = RMSNorm(config.dim)
        self.shared_attn_mask = SharedAttentionMask(config)

        if config.mimic_attn_layer is not None and config.mimic_attn_layer > 0 and config.mimic_attn_layer <= config.n_layer:
            self.mimic_attn = self.layers[config.mimic_attn_layer-1].attn
        else:
            self.mimic_attn = None

        # initialize weights
        _init_weight(self, self.layer_init_factor)

    def get_input_device(self):
        return self.embed_tokens.weight.device


    def forward(self, tokens, start_idx=0): # returns hidden states for all positions
        batch_size, seqlen = tokens.size()
        if seqlen > self.config.max_token_len:
            raise ValueError(f"input sequence length {seqlen} exceeds maximum sequence length {self.config.max_token_len}")

        # create token embeddings from token table; x.shape = (batch_size, seqlen, dim)
        h = self.embed_tokens(tokens)
        assert h.size() == (batch_size, seqlen, self.config.dim)

        h = self.ln_before(h)

        # transformer layers
        attn_mask = self.shared_attn_mask(seqlen, seqlen, q_start_idx=start_idx) # TODO: it will not work if q_seq_len != kv_seq_len
        for layer in self.layers:
            h = layer(h, attn_mask, start_idx=start_idx) # h.shape = (batch_size, seqlen, dim)

        return h, None


    def get_num_params(self):
        """
        Return the number of trainable parameters in a single transformer block;
        the model stacks n_layer identical blocks.
        """
        n_params = sum(p.numel() for p in self.layers[0].parameters() if p.requires_grad)
        return int(n_params)
    
    
    def get_intermediate_attn_v(self):
        if self.mimic_attn is None:
            return torch.zeros(1, 1, 1, 1), torch.zeros(1, 1, 1, 1)
        return self.mimic_attn.get_intermediate_attn_v()


class DearthForCausalLM(nn.Module):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: DearthConfig):
        super().__init__()
        self.model = DearthModel(config)
        self.dearth_config = config
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        torch.nn.init.xavier_normal_(self.lm_head.weight, gain=1)

        self.front_window_size = config.front_window_size
        self.sliding_window_size = config.sliding_window_size

    def get_input_device(self):
        return self.model.get_input_device()
    
    def get_intermediate_attn_v(self):
        return self.model.get_intermediate_attn_v()
    
    def print_all_params(self):
        for name, param in self.named_parameters():
            print(f"name: {name}, param.shape: {param.shape}")

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of the input sequence tokens in the vocabulary.
            use_cache (`bool`, *optional*):
                Reserved for a future KV cache; currently unused.

        Returns:
            Tuple: the first element is the logits tensor of shape
            `(batch_size, sequence_length, vocab_size)`.
        """
        outputs = self.model(
            tokens=input_ids
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        output = (logits,) + outputs[1:]
        return output
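
# Usage sketch (illustrative; the DearthConfig field values are assumptions, field
# names follow the attributes referenced in this file):
#
#   config = DearthConfig(vocab_size=32000, max_token_len=1024, dim=512,
#                         n_layer=8, n_head=8, ...)
#   model = DearthForCausalLM(config)
#   input_ids = torch.randint(0, config.vocab_size, (2, 16))
#   logits = model(input_ids)[0]  # (2, 16, vocab_size)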


def _init_weight(model, weight_init_factor): # TODO: update this if the model structure changes
    small_list = {'wv', 'wo', 'w1', 'w2', 'w3'}
    norm_list = {'ln_before', 'ln_2', 'ln_1'}
    for name, p in model.named_parameters():
        precise_name = name.split(".")[-2]
        if "bias" in name:
            logging.debug(f"the parameter {name} is initialized with 0.0")
            p.data.fill_(0.0)
        elif precise_name in small_list:
            logging.debug(f"the parameter {name} is initialized with gain={weight_init_factor}")
            torch.nn.init.xavier_normal_(p, gain=weight_init_factor)
        elif precise_name in norm_list:
            logging.debug(f"the parameter {name} is initialized with 1.0")
            p.data.fill_(1.0)
        else:
            logging.debug(f"the parameter {name} is initialized with gain=1.0")
            torch.nn.init.xavier_normal_(p, gain=1)