Abstract
Transformer tends to overallocate attention to irrelevant context. In this work, we introduce Diff Transformer, which amplifies attention to the relevant context while canceling noise. Specifically, the differential attention mechanism calculates attention scores as the difference between two separate softmax attention maps. The subtraction cancels noise, promoting the emergence of sparse attention patterns. Experimental results on language modeling show that Diff Transformer outperforms Transformer in various settings of scaling up model size and training tokens. More intriguingly, it offers notable advantages in practical applications, such as long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization. For in-context learning, Diff Transformer not only enhances accuracy but is also more robust to order permutation, which was considered as a chronic robustness issue. The results position Diff Transformer as a highly effective and promising architecture to advance large language models.
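A minimal sketch of the mechanism described above, assuming PyTorch-style tensors (function and variable names are illustrative, not the authors' released code):

```python
import torch
import torch.nn.functional as F

def diff_attn(q1, k1, q2, k2, v, lam):
    """Differential attention: subtract two softmax attention maps before applying V.

    q1, q2, k1, k2: (batch, heads, seq, d); v: (batch, heads, seq, d_v); lam: scalar.
    """
    d = q1.shape[-1]
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) / d ** 0.5, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) / d ** 0.5, dim=-1)
    return (a1 - lam * a2) @ v  # the subtraction cancels common-mode "noise" scores

# Toy usage with made-up shapes
q1, k1 = torch.randn(1, 4, 16, 32), torch.randn(1, 4, 16, 32)
q2, k2 = torch.randn(1, 4, 16, 32), torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 64)
out = diff_attn(q1, k1, q2, k2, v, lam=0.8)   # (1, 4, 16, 64)
```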
Community
We observe that Diff Transformer allocates lower attention scores to attention sinks, i.e., the first few tokens in the sequence.
Specifically, in the language modeling task, Diff Transformer allocates less than 5% of the attention scores to the BOS token, while Transformer allocates about 25%. For the key information retrieval task, please refer to Figure 1 in the paper. We find that models attend to the BOS token more when there is less useful information in the context.
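As a rough illustration of how such a share could be measured (a hypothetical helper, not the authors' evaluation code):

```python
import torch

def bos_attention_share(attn: torch.Tensor) -> float:
    """Average fraction of attention mass that queries place on position 0 (the BOS token).

    attn: (batch, heads, seq, seq) attention map (post-softmax, or post-subtraction
    in the differential case).
    """
    return attn[..., 0].mean().item()
```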
Great stuff. I would love to see comparisons against MöbiusAttention, which learns to forget... but this seems way more computationally efficient.
Thanks for pointing out this paper. We will look into it.
It is a neat approach, but one that comes with a tradeoff, IIUC: doubling the key heads.
I wonder if a different approach without that issue exists. For instance, using max(0, exp(x)-1) instead of exp(x) in the softmax attention formula. That way when the query is orthogonal to the key (or worse), it does not contribute.
In Diff Transformer, we split heads instead of doubling them, so no extra QK projection parameters are introduced. The Q and K heads are split into two groups and computed in pairs, and each pair shares the same V with dimension 2d. With this design, we match the FLOPs and parameter count of Transformer.
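A rough sketch of that head-splitting layout, assuming PyTorch (dimensions are illustrative and this is not the released implementation); the paired halves would then feed the two softmax maps as in the sketch under the abstract:

```python
import torch
import torch.nn as nn

batch, seq, d_model, num_heads = 2, 16, 512, 4
d = d_model // num_heads // 2               # per-half Q/K head dimension (64 here)

x = torch.randn(batch, seq, d_model)
# Same-sized projections as a standard attention layer: no extra QK parameters.
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Split Q and K into 2 * num_heads halves of size d and pair them up;
# each pair shares one V head of dimension 2 * d.
q = W_q(x).view(batch, seq, 2 * num_heads, d).transpose(1, 2)   # (2, 8, 16, 64)
k = W_k(x).view(batch, seq, 2 * num_heads, d).transpose(1, 2)
v = W_v(x).view(batch, seq, num_heads, 2 * d).transpose(1, 2)   # (2, 4, 16, 128)
q1, q2 = q[:, 0::2], q[:, 1::2]
k1, k2 = k[:, 0::2], k[:, 1::2]
```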
Using max(0, exp(x)-1) might be an approach that solves the problem. We didn't try it because we believe the properties of exp() are important for learning.
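For reference, the commenter's suggested alternative could be prototyped roughly like this (a sketch only; not something evaluated in the paper):

```python
import torch

def relu_exp_attention(q, k, v, eps=1e-9):
    """Attention with max(0, exp(x) - 1) in place of exp(x) before normalization.

    A query that is orthogonal to (or anti-aligned with) a key contributes zero weight.
    q, k: (..., seq, d); v: (..., seq, d_v).
    """
    d = q.shape[-1]
    scores = q @ k.transpose(-1, -2) / d ** 0.5
    w = torch.clamp(torch.exp(scores) - 1.0, min=0.0)
    w = w / (w.sum(dim=-1, keepdim=True) + eps)   # eps guards an all-zero row
    return w @ v
```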
Great work! Just wondering, do you have any idea why the two learned attention maps tend to cancel noise rather than signal? For instance, if attention 1 learns S + N_1 and attention 2 learns S + N_2 (where S is signal and N_1, N_2 are different noises), then subtracting the two cancels the signal S while the noise becomes N_1 - N_2, which could be even more complicated. Is there any reason why the model would not do this instead?
It's a good question. Our observation is that the model knows what is signal and what is noise. Since attention_1 and attention_2 are both calculated with learnable parameters, they can "perceive" each other during training and adjust to each other to achieve a lower loss. The result is that the model chooses to preserve signal and cancel out noise as long as we give it the chance to do so. A single softmax, by contrast, has difficulty learning the same solution due to its formulation and gradient properties.
This is an automated message from the Librarian Bot. I found the following papers similar to this paper.
The following papers were recommended by the Semantic Scholar API
- FocusLLM: Scaling LLM's Context by Parallel Decoding (2024)
- Selective Attention Improves Transformer (2024)
- Masked Mixers for Language Generation and Retrieval (2024)
- PEAR: Position-Embedding-Agnostic Attention Re-weighting Enhances Retrieval-Augmented Generation with Zero Inference Overhead (2024)
- Mamba Retriever: Utilizing Mamba for Effective and Efficient Dense Retrieval (2024)
Sure. lambda is multiplied with the softmax output, where softmax = exp(qk) / Σ exp(qk). The parameters in lambda learn at the same rate as the other parameters in the model, so lambda should take a similar form to softmax. That's why lambda = exp(lambda_q * lambda_k) + lambda_init. Moreover, to enable lambda to learn values smaller than lambda_init, we add a second term, i.e., lambda = exp(lambda_q1 * lambda_k1) - exp(lambda_q2 * lambda_k2) + lambda_init.
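In code, that reparameterization might look roughly like the following (a sketch assuming PyTorch; initialization and shapes are illustrative, and the paper additionally schedules lambda_init by layer depth):

```python
import torch
import torch.nn as nn

class DiffLambda(nn.Module):
    """lambda = exp(lambda_q1 * lambda_k1) - exp(lambda_q2 * lambda_k2) + lambda_init."""

    def __init__(self, head_dim: int, lambda_init: float = 0.8):
        super().__init__()
        self.lambda_init = lambda_init
        self.lambda_q1 = nn.Parameter(torch.randn(head_dim) * 0.1)
        self.lambda_k1 = nn.Parameter(torch.randn(head_dim) * 0.1)
        self.lambda_q2 = nn.Parameter(torch.randn(head_dim) * 0.1)
        self.lambda_k2 = nn.Parameter(torch.randn(head_dim) * 0.1)

    def forward(self) -> torch.Tensor:
        lam1 = torch.exp(torch.dot(self.lambda_q1, self.lambda_k1))
        lam2 = torch.exp(torch.dot(self.lambda_q2, self.lambda_k2))
        return lam1 - lam2 + self.lambda_init
```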
What kind of hardware was required to train this, and how did the tokens per second output compare with transformers?
There are no particular hardware requirements if you use the naive implementation. If you use flashdiff, refer to the FlashAttention repo (https://github.com/Dao-AILab/flash-attention) for hardware and data-type requirements.
Our speed test was performed on NVIDIA H100 80GB GPUs, and we measured throughput (tokens per second). The same cards and environment were used for both Diff Transformer and Transformer.
The work looks exciting and I really like the motivation coming from noise cancellation!
I have a few questions -
Won't this model let the post-attention weight (softmax(...) - \lambda * softmax(...)) for some value vectors be negative? Is that a design choice? One explanation does come to mind, i.e., wanting to get opposing contributions from some tokens specifically, but I am unsure if this is desired.
This recent work (https://arxiv.org/pdf/2410.01104) shows that attention will disperse given a few conditions (see Lemma 2.1, Page 3). Do you think differential attention is any different? If I understand the proposal correctly, I think it still satisfies Lemma 2.1 with some minor modifications in the proof.
Thanks again for your wonderful work!
Yes, there are some negative values in the post-subtraction weight, and that's what we want. The design can expand the representation space of attention weights, which promotes modeling capability. The model is free to allocate positive or negative values to tokens.
If I understand correctly, Diff can break the property in Lemma 2.1. In the paper, Equation 4 points out that the values of a single softmax output have a positive lower bound, because the input logits can't reach negative infinity. However, by taking the difference of two softmax maps, the output range includes 0, which means the attention weights are not O(1/n) anymore. This breaks Lemma 2.1. Diff can generate 0 as an attention value and assign it to unwanted context if it wants, while leaving almost all of the attention for the key information.
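A tiny numeric illustration of that point (toy logits chosen by hand, not taken from a trained model):

```python
import torch
import torch.nn.functional as F

a1 = F.softmax(torch.tensor([2.0, 0.0, 0.0, 0.0]), dim=-1)
a2 = F.softmax(torch.tensor([0.0, 2.0, 2.0, 2.0]), dim=-1)

print(a1.min())        # > 0: a single softmax is always bounded away from zero
print(a1 - 0.5 * a2)   # the difference can reach zero or go negative, here ~[0.69, -0.06, -0.06, -0.06]
```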
Here's my summary of this paper:
⚡️ Most important breakthrough this month: Differential Transformer vastly improves attention → better retrieval and fewer hallucinations!
Thought that self-attention could not be improved anymore?
Microsoft researchers have dropped a novel "differential attention" mechanism that amplifies focus on relevant context while canceling out noise. It sounds like a free lunch, but it really does seem to vastly improve LLM performance!
Key insights:
- Differential attention computes the difference between two separate softmax attention maps, canceling out noise and promoting sparse attention patterns
- DIFF Transformer outperforms standard Transformers while using 35-40% fewer parameters or training tokens
- Scales well to long contexts up to 64K tokens, leveraging increasing context length more effectively
- Dramatically improves key information retrieval, enhances in-context learning, and possibly reduces the risk of hallucinations
- Reduces activation outliers, potentially enabling lower-bit quantization without a performance drop
- Can be directly implemented using existing FlashAttention kernels (see the sketch below)
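One plausible way to realize that with off-the-shelf fused kernels (a sketch; PyTorch's scaled_dot_product_attention stands in here for the FlashAttention kernels the authors use):

```python
import torch
import torch.nn.functional as F

def diff_attn_fused(q1, k1, q2, k2, v, lam, causal=True):
    # By linearity, (softmax(Q1 K1^T) - lam * softmax(Q2 K2^T)) V equals the
    # difference of two standard attention outputs that share the same V, so
    # two fused-attention calls followed by a subtraction suffice.
    o1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=causal)
    o2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=causal)
    return o1 - lam * o2
```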
This new architecture could lead to much more capable LLMs, with vastly improved strengths in long-context understanding and factual accuracy.
But they didn't release weights on the Hub: let's wait for the community to train the first open-weights DiffTransformer!
This paper was a great read. We wrote a summary blog about this paper and a few others, including:
- TPI LLM
- Differential Transformer
- ARIA
You can find it here. Please give it a read :)
I only have one burning question about this paper: is this architecture compatible with the attention mechanism described in "Selective Attention Improves Transformer"?
Hi, we haven't tried combining them. Diff Transformer and Selective Attention are proposed from different perspectives and solve different problems. I believe they are compatible.
The proposed approach sounds intriguing. Thanks for your work!
Can you provide any intuition and/or theoretical justification on why vanilla softmax attention fails to deal with noisy tokens in a proper way? Where are the weaknesses in its structure that prevent ignoring irrelevant tokens in a sequence and concentrating on the essential ones?
Hi, you can refer to a recent paper "softmax is not enough for sharp out-of-distribution" (https://arxiv.org/abs/2410.01104).
In simple terms: (1) softmax can't produce exactly zero scores, by definition; (2) producing near-zero scores requires a wide range of input logits, which harms backpropagation through the softmax. That's why the model fails to cancel out irrelevant tokens.
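To make point (1) concrete, here is a quick bound (assuming the attention logits x_i lie in [-B, B] for a length-n sequence):

```latex
\operatorname{softmax}(x)_i \;=\; \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}
\;\ge\; \frac{e^{-B}}{n\,e^{B}} \;=\; \frac{e^{-2B}}{n} \;>\; 0
```

So every token keeps at least e^{-2B}/n of the attention mass, and shrinking that floor requires spreading the logits over a wider range, which in turn flattens the softmax gradients.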