MLA: Redefining KV-Cache Through Low-Rank Projections and On-Demand Decompression

Community Article Published February 4, 2025

Introduction

As Large Language Models (LLMs) continue to thrive, hardware resources remain a daunting “ceiling”—especially with limited GPU memory (VRAM). The question of how to achieve extended context lengths and faster inference under constrained resources has long been a key focus for both engineering and research communities. Aside from the common techniques of quantization and pruning, there is a growing emphasis on “reducing the KV-Cache footprint at inference time”.

This article first revisits how MHA (Multi-Head Attention), MQA (Multi-Query Attention), and GQA (Grouped-Query Attention) handle or reduce K/V storage, and then centers on the approach introduced by DeepSeek, MLA (Multi-Head Latent Attention). Unlike those earlier methods that work mainly at the level of “K/V sharing,” MLA employs a combination of low-rank projection and on-demand decompression, allowing us to bypass storing multi-head K/V directly. MLA uses “latent vectors” and further leverages a matrix-merging trick so that, during inference, the attention mechanism only requires minimal VRAM usage to function.

It is worth noting that MLA, when deployed in real systems, often needs to accommodate RoPE (Rotary Position Embedding). To keep things clear, we will first explain the core MLA idea (low-rank projection) and then discuss how to integrate RoPE. We hope this structured approach provides insight into the reasoning and nuances behind MLA’s design.

Special Thanks: Part of the inspiration for this article comes from Su Jianlin’s blog post 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA (“The Extreme Tug-of-War Between Cache and Performance: From MHA, MQA, and GQA to MLA,” in Chinese). We extend our respect to him for his work.


1. Why Reduce the KV-Cache?

1.1 The “Invisible Bottleneck” in Long-Context Inference

During autoregressive generation with a Transformer, each newly generated token references the historical Key/Value $(K, V)$ vectors from all previous tokens. These stored Key/Value vectors constitute the KV-Cache during inference. If the sequence length is $L$, with $h$ attention heads, and each head has dimensionality $d_k$ or $d_v$, then the total KV-Cache scales roughly as $L \times h \times (d_k + d_v)$, growing linearly with $L$. Once the context reaches thousands or even tens of thousands of tokens, the KV-Cache can dominate VRAM usage, surpassing even the model’s own activation storage.
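To make that scaling concrete, here is a small back-of-the-envelope sketch; the layer count, head count, and fp16 storage below are illustrative assumptions, not tied to any particular model:

```python
# Rough KV-Cache estimate for standard MHA: L x h x (d_k + d_v) elements per layer.
def kv_cache_bytes(seq_len, n_layers, n_heads, d_k, d_v, bytes_per_elem=2):  # 2 bytes ~ fp16
    per_token_per_layer = n_heads * (d_k + d_v)   # h heads, each caching a Key and a Value
    return seq_len * n_layers * per_token_per_layer * bytes_per_elem

# Hypothetical 32-layer, 32-head model with 128-dim heads at a 32k-token context.
print(kv_cache_bytes(32_000, 32, 32, 128, 128) / 2**30, "GiB")  # grows linearly with seq_len
```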

1.2 Constraints of VRAM and Bandwidth

When the sequence is very long, fitting all KV-Cache on a single GPU becomes unfeasible. Splitting the cache across multiple GPUs or machines leads to bandwidth bottlenecks—inter-device communication is typically much slower than on-device memory access. Simply put, if we can handle longer contexts on fewer GPUs, we minimize communication overhead and improve throughput. This is precisely why MQA, GQA, and MLA have emerged and keep evolving.


2. MHA → MQA → GQA: Simplifying K/V in Multi-Head Attention

Before introducing MLA, let’s briefly overview conventional Multi-Head Attention (MHA) as well as the sharing-based approaches MQA and GQA, which aim to reduce K/V storage. This context sets the stage for comparing MLA to prior work.

2.1 Multi-Head Attention (MHA) Foundations

2.1.1 Classic Attention Formulas

In a Transformer, a token sequence $\mathbf{x}_1,\dots,\mathbf{x}_l$ is projected into multiple sets of $(Q, K, V)$ for attention computation. For the $s$-th head, assuming hidden dimensionality $d$:

$$\mathbf{q}_i^{(s)} = \mathbf{x}_i \mathbf{W}_q^{(s)},\quad \mathbf{k}_i^{(s)} = \mathbf{x}_i \mathbf{W}_k^{(s)},\quad \mathbf{v}_i^{(s)} = \mathbf{x}_i \mathbf{W}_v^{(s)}.$$

Under autoregressive decoding, the attention score at step $t$ is often written as

$$\alpha_{t,i}^{(s)} = \mathbf{q}_t^{(s)} \,\mathbf{k}_i^{(s)\top}, \quad\text{for } i \le t.$$

To speed up inference, we cache the computed $\mathbf{k}_i^{(s)}$ and $\mathbf{v}_i^{(s)}$ in VRAM for later tokens, a storage referred to as the KV-Cache.
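For reference, here is a minimal sketch of one MHA decoding step with an explicit KV-Cache; all shapes and weight names are illustrative assumptions rather than a reference implementation:

```python
import torch

# Minimal single-layer MHA decode step with an explicit, growing KV-Cache.
d, h, d_k = 256, 8, 32
Wq, Wk, Wv = torch.randn(d, h * d_k), torch.randn(d, h * d_k), torch.randn(d, h * d_k)

k_cache, v_cache = [], []              # one (h, d_k) entry per past token
def decode_step(x_t):                  # x_t: (d,)
    q = (x_t @ Wq).view(h, d_k)
    k_cache.append((x_t @ Wk).view(h, d_k))
    v_cache.append((x_t @ Wv).view(h, d_k))
    K = torch.stack(k_cache)           # (t, h, d_k): every past Key stays in memory
    V = torch.stack(v_cache)
    scores = torch.einsum("hd,thd->ht", q, K) / d_k ** 0.5
    attn = scores.softmax(dim=-1)
    return torch.einsum("ht,thd->hd", attn, V).reshape(-1)

out = decode_step(torch.randn(d))      # cache now holds one token of K/V for every head
```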

2.1.2 Pressures on VRAM

Because MHA typically retains distinct K/V for each head, if $h$ is large, you end up storing $h$ sets of Key/Value, which can quickly blow up VRAM usage. Researchers therefore wondered: can we make multi-head attention share or compress these K/V representations?

2.2 MQA: Extreme K/V Sharing

MQA (Multi-Query Attention) lets every head share a single K/V pair:

$$\mathbf{k}_i = \mathbf{x}_i \mathbf{W}_k,\quad \mathbf{v}_i = \mathbf{x}_i \mathbf{W}_v,$$

while each head still retains its own $\mathbf{q}_i^{(s)}$. The KV-Cache then holds just 1 set of K/V instead of $h$, cutting its size to $1/h$ of the MHA baseline. Implementations such as PaLM and StarCoder have adopted MQA. However, because all heads share the same K/V, certain tasks may see degraded performance unless additional training strategies are used.

2.3 GQA: Grouping Heads

If MQA feels overly aggressive, GQA (Grouped-Query Attention) provides a middle ground: group the $h$ heads into $g$ clusters, each cluster sharing one set of K/V. Hence, the KV-Cache shrinks to $g$ sets (rather than $h$). Examples include LLaMA2-70B and ChatGLM2/3. GQA retains greater head diversity than MQA but still saves VRAM over standard MHA.
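To see the difference purely in terms of what gets cached per token, here is a tiny shape-only sketch (the dimensions and group count are illustrative assumptions):

```python
import torch

d, h, g, d_k = 4096, 32, 4, 128  # hypothetical dims; g = number of GQA groups
x = torch.randn(1, d)            # one token

# MHA: h independent Key projections -> cache h * d_k values per token
k_mha = (x @ torch.randn(d, h * d_k)).view(h, d_k)       # (32, 128)

# MQA: a single shared Key -> cache d_k values per token
k_mqa = x @ torch.randn(d, d_k)                          # (1, 128)

# GQA: g shared Keys, one per group of h/g heads -> cache g * d_k values per token
k_gqa = (x @ torch.randn(d, g * d_k)).view(g, d_k)       # (4, 128)

print(k_mha.numel(), k_mqa.numel(), k_gqa.numel())       # 4096, 128, 512
```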

2.4 Comparison of MHA / MQA / GQA

| Method | KV-Cache Storage | K/V Sharing? | VRAM Savings | Head Diversity |
| --- | --- | --- | --- | --- |
| MHA | Stores $h$ copies | Independent per head | Low (baseline) | High |
| MQA | Stores 1 copy | Fully shared K/V | High | Lower |
| GQA | Stores $g$ copies | Shared by groups | Moderate | Fairly high |

Whether MQA or GQA, they both revolve around the question of “how much to share K/V across heads.” In contrast, MLA rethinks “what we actually store at inference time”: It shifts most of the K/V content into a latent vector, reconstructing it only on demand. Let’s explore how MLA uses low-rank projection and on-demand decompression, initially without worrying about RoPE.


3. The MLA Core: Low-Rank Projections and On-Demand Reconstruction (Without RoPE)

3.1 Key Idea: Switching from “Store Multi-Head K/V” to “Store a Low-Dimensional Latent Vector”

In MLA (Multi-Head Latent Attention), we still project each input into Key and Value at training time, but we introduce a low-rank latent vector $\mathbf{c}_i$. During inference, instead of caching high-dimensional multi-head K/V, we only store this compact $\mathbf{c}_i$, then merge matrices when needed to reconstruct the multi-head K/V. Concretely, ignoring RoPE for the moment, we can represent MLA’s training step like so:

$$\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c \quad (\text{a low-rank projection, } \mathbf{W}_c \in \mathbb{R}^{d \times d_c}),$$

and for each head $s$ we define projection matrices $\mathbf{W}_{kc}^{(s)}, \mathbf{W}_v^{(s)}$ such that

$$\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}.$$

As a result, no matter how many heads we have, only training time sees explicit multi-head K/V. At inference, we simply cache the latent vector $\mathbf{c}_i$, reconstructing K/V on the fly using matrix combinations. This is how MLA largely decouples the KV-Cache cost from the number of heads.
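A minimal sketch of this training-time view, with illustrative dimensions ($d=4096$, $d_c=512$, 32 heads of size 128); the weight names simply mirror the formulas above and are not DeepSeek’s parameter names:

```python
import torch

d, d_c, h, d_k, d_v = 4096, 512, 32, 128, 128    # illustrative sizes
x_i = torch.randn(d)

W_c  = torch.randn(d, d_c)                       # low-rank down-projection
W_kc = torch.randn(h, d_c, d_k)                  # per-head Key up-projection
W_v  = torch.randn(h, d_c, d_v)                  # per-head Value up-projection

c_i = x_i @ W_c                                  # only this (d_c,) vector is cached
k_i = torch.einsum("c,hck->hk", c_i, W_kc)       # reconstructed multi-head Keys   (h, d_k)
v_i = torch.einsum("c,hcv->hv", c_i, W_v)        # reconstructed multi-head Values (h, d_v)
print(c_i.shape, k_i.shape, v_i.shape)           # [512], [32, 128], [32, 128]
```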

3.2 On-Demand Decompression: How VRAM is Saved

During inference, whenever we generate token $t$ and need to compute the dot product $\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}$ for previous tokens $i < t$, a conventional approach would read $\mathbf{k}_i^{(s)}$ from VRAM. MLA, however, merges them:

$$\mathbf{k}_i^{(s)} = \mathbf{c}_i\,\mathbf{W}_{kc}^{(s)} \quad\Longrightarrow\quad \mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top} = (\mathbf{x}_t \mathbf{W}_q^{(s)})(\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top.$$

Thanks to the associativity of matrix multiplication, we can merge the projection matrices:

$$(\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top = \mathbf{x}_t \bigl(\mathbf{W}^{(s)}_q \mathbf{W}^{(s)\top}_{kc}\bigr)\,\mathbf{c}_i^\top = \mathbf{x}_t \mathbf{W}_{\mathrm{merged}}^{(s)}\,\mathbf{c}_i^\top.$$

Hence, we only keep $\mathbf{c}_i$ in VRAM, never storing all the $\mathbf{k}_i^{(s)}$.
If $\mathbf{c}_i$ is $d_c$-dimensional, with $d_c \ll h \times d_k$, the per-token KV-Cache cost drops from $h \times d_k$ to $d_c$.
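Since the merge is nothing more than associativity, the identity can be checked numerically; a tiny float64 sketch with made-up sizes:

```python
import torch

torch.manual_seed(0)
d, d_c, d_k = 64, 16, 8
x_t  = torch.randn(d,   dtype=torch.float64)
c_i  = torch.randn(d_c, dtype=torch.float64)
W_q  = torch.randn(d,   d_k, dtype=torch.float64)
W_kc = torch.randn(d_c, d_k, dtype=torch.float64)

lhs = (x_t @ W_q) @ (c_i @ W_kc)     # reconstruct k_i, then take the dot product with q_t
W_merged = W_q @ W_kc.T              # fold W_kc into the Query projection once
rhs = (x_t @ W_merged) @ c_i         # dot product taken directly against the latent vector
print(torch.allclose(lhs, rhs))      # True (up to floating-point error)
```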

3.3 How Low-Rank Projection Drastically Reduces Storage

One might ask: how large can the compression ratio get? Suppose $d = 4096$, $h = 32$, and a single-head dimension of $d_k = 128$. In standard MHA, each token’s Key occupies $32 \times 128 = 4096$ elements (and similarly for the Value). MLA might set the latent vector $\mathbf{c}_i$ to 512 elements, cutting that per-token storage from 4096 to 512, an 8× reduction. In more extreme setups, the compression factor can reach tens or even hundreds.

Of course, this is the rosy scenario without position encoding. In practice, Transformers commonly use RoPE (Rotary Position Embedding), which modifies how $Q$ and $K$ are projected. So we’ll first clarify MLA’s fundamental low-rank approach before examining how RoPE fits in. Next, we illustrate MLA’s workflow with an “intelligent photo album” analogy, and then return to RoPE afterward.


4. Understanding MLA as a “Low-Rank Thumbnail” in an Intelligent Photo Album

Even if you appreciate MLA’s formulas, you might still find them abstract. Let’s use a more intuitive metaphor: treat each “Token” as a “photograph,” “multi-head attention” as “filters applied to photos,” and “KV Cache” as “album storage.” This analogy shows how MLA achieves compressed storage and on-demand decompression.

4.1 Photo Storage: A Low-Rank Thumbnail

Imagine every time you snap a photo (process a token $\mathbf{x}_i$), instead of saving the “full-resolution image plus all filters,” you only keep a “small-yet-informative thumbnail”: the latent vector $\mathbf{c}_i$. For instance, if the original image resolution is 4096 (like $d = 4096$ in MLA), the thumbnail might be 512 in size, roughly 1/8 of the original.
Mathematically,

$$\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c, \quad \mathbf{W}_c \in \mathbb{R}^{4096 \times 512}.$$

This is akin to “downsampling the photo at capture time,” drastically cutting storage overhead.

4.2 Viewing Photos: Real-Time Decompression

When you “view” a photo with a certain filter (an attention head generating its Key/Value), MLA computes:

$$\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}.$$

So no matter how many filters (heads) exist, you only keep the thumbnail $\mathbf{c}_i$ rather than multiple versions of the same image. At inference, each filter’s parameters can reconstruct the Key or Value from that thumbnail, resulting in a massive storage reduction.

4.3 On-Demand Reconstruction: Merging at Calculation Time

In actual inference, a “merge” step is performed right when computing the dot product:

$$\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top} = (\mathbf{x}_t \mathbf{W}_q^{(s)})(\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top = \mathbf{x}_t \mathbf{W}_{\mathrm{merged}}^{(s)}\,\mathbf{c}_i^\top.$$

Hence, the filter parameters $\mathbf{W}_{kc}^{(s)}$ do not have to be stored in the “album” for each picture (head). Only the compact thumbnail $\mathbf{c}_i$ remains. This approach omits the repeated Key/Value overhead that standard multi-head attention requires.


5. The RoPE Challenge: Why Add a “Position Sticker”?

Now that we’ve covered MLA’s low-rank approach and on-demand reconstruction, we must address a practical reality. Transformers often rely on RoPE (Rotary Position Embedding) to incorporate positional information into Key/Query. RoPE complicates the straightforward “latent vector only” approach, since each token’s position $i$ introduces a rotation matrix $\mathcal{R}_i$ that affects the dot product differently.

5.1 RoPE: Timestamp and GPS Coordinates

Returning to our album analogy, RoPE becomes the “timestamp or GPS location” on each photo. Besides the core visual content ($\mathbf{c}_i$), we retain a small sticker that encodes when and where the photo was taken. If we tried to embed the time data directly into the thumbnail, relative distances (time differences) might be lost. Thus, in MLA, a small portion of the Key/Query dimensions remains explicitly multiplied by $\mathcal{R}_i$, preserving relative position even under the low-rank scheme.
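The “relative distance” that must survive is the rotary identity $\mathcal{R}_m \mathcal{R}_n^\top = \mathcal{R}_{m-n}$, which the next subsection relies on; a two-dimensional sanity check of that identity (purely illustrative):

```python
import torch

def R(theta: float) -> torch.Tensor:
    # 2-D rotation matrix, the basic building block of RoPE's block-diagonal rotations.
    c, s = torch.cos(torch.tensor(theta)), torch.sin(torch.tensor(theta))
    return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])

m, n = 0.7, 0.3
print(torch.allclose(R(m) @ R(n).T, R(m - n)))  # True: only the relative offset m - n survives
```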

5.2 Split Strategy: $\mathbf{c}_i$ + a Small RoPE Dimension

In formal terms, to preserve the rotation property $\mathcal{R}_m \mathcal{R}_n^\top = \mathcal{R}_{m-n}$, MLA splits each Key (and similarly each Query) into two parts:

$$\mathbf{k}_i^{(s)} = \underbrace{\bigl(\mathbf{c}_i \mathbf{W}_{kc}^{(s)}\bigr)}_{\text{compressed portion}} \;\oplus\; \underbrace{\bigl(\mathbf{x}_i \mathbf{W}_{kr}\,\mathcal{R}_i\bigr)}_{\text{positional portion}},$$

so the KV-Cache only grows by a modest “position dimension,” while the main storage remains the latent vector $\mathbf{c}_i$. This design deftly fuses low-rank projection with rotary embeddings for relative positions.
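A small sketch of this split, echoing the `kv` / `k_pe` separation in the code of Section 9; the dimensions and the simple rotation helper below are illustrative assumptions, not DeepSeek’s exact RoPE implementation:

```python
import torch

d, d_c, d_r = 4096, 512, 64                      # illustrative: latent dim + small RoPE dim
x_i = torch.randn(d)
W_c, W_kr = torch.randn(d, d_c), torch.randn(d, d_r)

def rotate(v, pos, base=10000.0):
    # RoPE-style rotation of consecutive (even, odd) pairs by position-dependent angles.
    half = v.view(-1, 2)
    freqs = base ** (-torch.arange(half.size(0), dtype=torch.float32) / half.size(0))
    cos, sin = torch.cos(pos * freqs), torch.sin(pos * freqs)
    return torch.stack([half[:, 0] * cos - half[:, 1] * sin,
                        half[:, 0] * sin + half[:, 1] * cos], dim=-1).flatten()

c_i = x_i @ W_c                                  # compressed portion: cached as-is
k_pe = rotate(x_i @ W_kr, pos=7)                 # positional portion: small, rotated, cached separately
cache_entry = (c_i, k_pe)                        # per-token cache: 512 + 64 elements instead of h * d_k
```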


6. The Comprehensive Benefits of MLA: Storage Innovation, Flexible Retrieval, and Spatial-Temporal Fidelity

After breaking MLA down into core steps, we can see three major upsides:

  • Storage: Traditional multi-head attention must store K/V for every head. MLA instead uses a low-rank latent vector $\mathbf{c}_i$ plus a small dimension for RoPE if needed. VRAM usage can shrink by factors of several times or more.
  • Computation: With on-demand decompression, Key/Value are reconstructed only when necessary, and the reconstruction can be further optimized by merging it into the Query projection. For very long sequences, memory bandwidth is the real bottleneck, so reducing repeated K/V reads significantly speeds up inference.
  • Position: When a model requires relative position (RoPE), MLA keeps a separate “position sticker” dimension. This preserves temporal/spatial information without forcing the entire Key space to be stored separately.

7. From an Engineering Angle: Key Considerations

7.1 Balancing VRAM and Speed

If your application involves sequences of thousands or tens of thousands of tokens, MLA’s latent compression helps reduce the KV-Cache drastically and allows a single GPU to process more tokens or bigger batch sizes, thus boosting throughput.

7.2 Tuning RoPE Dimensions

When splitting K into a low-rank zone ($\mathbf{c}_i$) vs. a RoPE zone, if the RoPE dimension is too small, extremely long contexts may not get enough positional signal. Conversely, if it is too large, MLA’s compression advantages diminish. The optimal trade-off typically emerges from empirical experiments.

7.3 Numerical Stability and Precision

Because MLA merges weight matrices at inference time, using BF16/FP16 can introduce small accumulated errors, since the order of multiplications changes. This is generally acceptable. If your application is ultra-sensitive to accuracy, consider higher-precision accumulators or a partial float32 fallback.
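As a rough illustration of why the multiplication order matters numerically (a tiny toy example using float32 as a stand-in for BF16/FP16 so it runs anywhere; not a benchmark):

```python
import torch

torch.manual_seed(0)
d, d_c, d_k = 1024, 128, 64
x = torch.randn(d,   dtype=torch.float64)
c = torch.randn(d_c, dtype=torch.float64)
W_q  = torch.randn(d,   d_k, dtype=torch.float64) / d ** 0.5
W_kc = torch.randn(d_c, d_k, dtype=torch.float64) / d_c ** 0.5

ref = (x @ W_q) @ (c @ W_kc)                                 # float64 reference (unmerged order)

f = torch.float32                                            # stand-in for reduced precision
unmerged = ((x.to(f) @ W_q.to(f)) @ (c.to(f) @ W_kc.to(f))).double()
merged   = ((x.to(f) @ (W_q.to(f) @ W_kc.to(f).T)) @ c.to(f)).double()
print(abs(unmerged - ref).item(), abs(merged - ref).item())  # both tiny, but generally not identical
```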


8. Overall Summary and Future Directions

MLA (Multi-Head Latent Attention) is more than merely “low-rank factorization of K/V.” It pushes further by caching only a latent vector $\mathbf{c}_i$ at inference time, reconstructing multi-head K/V via on-demand decompression and matrix merging, cutting KV-Cache usage drastically. Then, by applying a split strategy for RoPE, MLA retains relative positional information without forcing the entire K/V to remain explicit.

From an engineering perspective, MLA’s VRAM efficiency for long-context LLM inference is a big advantage, potentially expanding the number of tokens handled on a single card or small cluster. Deciding exactly how to partition the dimension between latent vectors and RoPE, however, depends on your task and model scale. This concept can also extend to other positional encodings (ALiBi, NTK Scaling, etc.) or specialized domains.

Regardless, MLA clearly shows a fresh and powerful path for reducing KV-Cache. We’ll likely see more variants of MLA combined with other attention optimizations, helping large models achieve even more performance under strict hardware constraints.


Appendix: Key Formulas and Their “Photo Album” Analogy

Below are the crucial formulas in MLA, alongside how they match the intelligent photo album metaphor:

  1. Low-Rank Projection

$$\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c$$

$\quad\updownarrow\quad$

“Store a thumbnail instead of a full photo,” drastically shrinking storage.

  2. Dynamic Decompression

$$\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_{kc}^{(s)}, \quad \mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)}$$

$\quad\updownarrow\quad$

“Generate filtered views (Key/Value) from the single thumbnail at runtime.”

  3. On-Demand Reconstruction (Merging Matrices)

$$\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top} = (\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_{kc}^{(s)})^\top$$

$\quad\updownarrow\quad$

“Combine the filter math with the viewing step, further reducing storage.”

  4. RoPE Split

$$\mathbf{k}_i^{(s)} = (\mathbf{c}_i \mathbf{W}_{kc}^{(s)}) \;\oplus\; (\mathbf{x}_i \mathbf{W}_{kr}\,\mathcal{R}_i)$$

$\quad\updownarrow\quad$

“Keep a separate small label for the timestamp or GPS, i.e., relative position.”

Reading through these steps, you can see how MLA seamlessly progresses from low-rank projection to on-demand Key/Value restoration and finally to coexisting with RoPE for position encoding. As its name implies, MLA both retains the potent expressiveness of multi-head attention while delegating the main K/V storage burden to a smaller latent representation—enabling longer contexts under limited GPU memory.


9. A Minimal Working Example of MLA

To give a more hands-on feel for MLA’s core concepts, here is a minimal, runnable code sample of an MLA-based MoE Transformer, illustrating how “latent variables,” “on-demand reconstruction,” and “RoPE integration” show up in an actual code structure. In the ModelArgs data class, the field attn_impl: Literal["naive", "absorb"] = "absorb" governs whether we store classical K,V in a naive style or rely on MLA’s latent caching (absorb). For the complete functionality, see the DeepSeek Official Repo.

Below is the consolidated example:

```python
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal

import torch
import torch.nn.functional as F
from torch import nn

@dataclass
class ModelArgs:
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 2
    n_dense_layers: int = 1
    n_heads: int = 16
    # moe
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    # mla
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    # rope
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.
    # kv-cache
    attn_impl: Literal["naive", "absorb"] = "absorb"

class RMSNorm(nn.Module):
    """Root Mean Square layer normalization."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return F.linear(x, weight, bias)


class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            nn.init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min_val, max_val, dim):
        if min_val == max_val:
            max_val += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)


class MLA(nn.Module):
    """
    Multi-Headed Attention Layer (MLA).
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = self.n_heads
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        self.attn_impl = args.attn_impl

        # Q
        if self.q_lora_rank == 0:
            self.wq = Linear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        # K,V
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = Linear(
            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
        )

        # Output
        self.wo = Linear(self.n_heads * self.v_head_dim, self.dim)

        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale *= mscale * mscale

        # register different buffer based on the choice of "attn_impl"
        if self.attn_impl == "naive":
            self.register_buffer(
                "k_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim),
                persistent=False
            )
            self.register_buffer(
                "v_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim),
                persistent=False
            )
        else:
            self.register_buffer(
                "kv_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
                persistent=False
            )
            self.register_buffer(
                "pe_cache",
                torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim),
                persistent=False
            )


    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen

        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)

        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        if self.attn_impl == "naive":
            # naive mode
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))  # (bsz, seqlen, n_heads*(qk_nope_head_dim + v_head_dim))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)

            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)

            # write cache
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v

            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
            out = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])

        else:
            # the absorb mode proposed in MLA
            wkv_b = self.wkv_b.weight 
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)

            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)

            # Fold the per-head Key up-projection (taken from wkv_b) into q_nope so that
            # attention scores can be computed directly against the cached latent vectors.
            q_nope = torch.einsum(
                "bshd,hdc->bshc",
                q_nope,
                wkv_b[:, :self.qk_nope_head_dim]
            )

            scores = (
                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
            ) * self.softmax_scale

            if mask is not None:
                scores += mask.unsqueeze(1)

            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

            out = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            out = torch.einsum(
                "bshc,hdc->bshd",
                out,
                wkv_b[:, -self.v_head_dim:]
            )

        out = self.wo(out.flatten(2))
        return out


class MLP(nn.Module):

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.
    """

    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        nn.init.xavier_normal_(self.weight)
        self.bias = nn.Parameter(torch.zeros(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x, self.weight, self.bias)

        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()

        original_scores = scores
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices_groups = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices_groups, True)
            scores = (scores * mask.unsqueeze(-1)).flatten(1)

        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_routed_experts = args.n_routed_experts
        self.n_activated_experts = args.n_activated_experts
        self.gate = Gate(args)

        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim) for _ in range(self.n_routed_experts)
        ])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)

        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.n_routed_experts):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]

        z = self.shared_experts(x)
        return (y + z).view(shape)


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = torch.nn.Embedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = Linear(args.dim, args.vocab_size)

        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)

        h = self.norm(h)[:, -1]
        logits = self.head(h)
        return logits


if __name__ == "__main__":
    torch.manual_seed(0)
    args = ModelArgs()
    x = torch.randint(0, args.vocab_size, (2, 128))
    model = Transformer(args)
    logits = model(x)
    print(logits.shape)  # (batch_size, vocab_size)
```
