yutaozhu94 committed on
Commit
32f55c3
1 Parent(s): c8aa697

Update README.md

Files changed (1)
  1. README.md +113 -3
README.md CHANGED
@@ -1,3 +1,113 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ ---
+
+ # SPRING: Learning Scalable and Pluggable Virtual Tokens for Retrieval-Augmented Large Language Models
+
+ <p>
+ <a href="https://github.com/DaoD/SPRING/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-MIT-blue" alt="license"></a>
+ <a href="https://arxiv.org/abs/2405.19670"><img src="https://img.shields.io/badge/Paper-Arxiv-red"></a>
+ <a href="https://huggingface.co/yutaozhu94/SPRING"><img src="https://img.shields.io/badge/Embeddings-%F0%9F%A4%97%20Hugging%20Face-8A2BE2"></a>
+ </p>
+
+ **Authors**: Yutao Zhu, Zhaoheng Huang, Zhicheng Dou, and Ji-Rong Wen
+
+ | Virtual token embeddings file | Backbone Model |
+ |:----------------------------------------------|:--------------------------------------------------------------------------------------|
+ | mistral.7b.instruct.added_token_embeddings.pt | [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) |
+ | mistral.7b.base.added_token_embeddings.pt | [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) |
+ | llama2.7b.chat.added_token_embeddings.pt | [LLaMA-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
+ | llama2.7b.base.added_token_embeddings.pt | [LLaMA-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) |
+ | llama2.13b.chat.added_token_embeddings.pt | [LLaMA-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) |
+ | llama2.13b.base.added_token_embeddings.pt | [LLaMA-2-13b](https://huggingface.co/meta-llama/Llama-2-13b-hf) |
+
+ ## News
+ - May 2024: We have released our paper on arXiv. The code and models are being prepared and will be released later.
+
+ ## Introduction
+
+ Retrieval-augmented generation (RAG) is a promising way to improve large language models (LLMs) for generating more factual, accurate, and up-to-date content. Existing methods either optimize prompts to guide LLMs in leveraging retrieved information or directly fine-tune the LLMs to adapt to RAG scenarios. Although fine-tuning can yield better performance, it often compromises the LLMs' general generation capabilities by modifying their parameters. This limitation poses challenges in practical applications, especially when LLMs are already deployed, as parameter adjustments may affect their original functionality. To address this, we propose a novel method that involves learning scalable and pluggable virtual tokens for RAG. By maintaining the LLMs’ original parameters and fine-tuning only the embeddings of these pluggable tokens, our approach not only enhances LLMs’ performance but also preserves their general generation capacities. Furthermore, we design several training strategies to improve the scalability, flexibility, and generalizability of our method. Comprehensive experiments across nine question-answering tasks demonstrate the superiority of our approach.
+
+ ## Usage
+
+ ### Required packages
+ ```
+ torch 2.0.0
+ transformers 4.37.0
+ ```
+
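To confirm that the local environment matches the versions listed above, a small check along the following lines may help. It is only a sketch: the helper name `check_versions` and the exact-match policy are assumptions of this example, not part of the SPRING code, and other releases may well work.

```python
from importlib import metadata

# Versions listed under "Required packages" above.
REQUIRED = {"torch": "2.0.0", "transformers": "4.37.0"}


def check_versions(required):
    """Return {package: (installed_version_or_None, matches_required)}.

    Hypothetical helper for illustration; it only reads installed package metadata.
    """
    report = {}
    for name, wanted in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local build suffixes such as "+cu118" when comparing.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[name] = (installed, matches)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_versions(REQUIRED).items():
        status = "ok" if ok else f"expected {REQUIRED[name]}"
        print(f"{name}: {installed} ({status})")
```

A mismatch reported here is a hint rather than an error; it only says that the installed version differs from the one listed above.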
+ ### Load token embeddings
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+ def load_tokens(model, tokenizer, token_embedding_path=""):
+     # load the trained virtual token embeddings: (num_new_tokens, embedding_dim)
+     new_tokens_weights = torch.load(token_embedding_path, map_location="cpu")
+     new_tokens_length = new_tokens_weights.shape[0]
+
+     # expand vocabulary
+     new_tokens = [f"[ref{i+1}]" for i in range(new_tokens_length)]
+     tokenizer.add_tokens(new_tokens)
+
+     # get original embedding weight matrix
+     embedding_layer = model.get_input_embeddings()
+     embedding_weights = embedding_layer.weight.data
+     original_vocab_size, embedding_dim = embedding_weights.shape
+
+     # create new embedding matrix with the same dtype and device as the original
+     new_vocab_size = original_vocab_size + new_tokens_length
+     new_embedding_weights = torch.zeros(
+         new_vocab_size, embedding_dim,
+         dtype=embedding_weights.dtype, device=embedding_weights.device,
+     )
+
+     # copy original embeddings to the new weights
+     new_embedding_weights[:original_vocab_size, :] = embedding_weights
+
+     # append virtual token embeddings to the new weights
+     for token, embedding in zip(new_tokens, new_tokens_weights):
+         token_id = tokenizer.convert_tokens_to_ids(token)
+         new_embedding_weights[token_id] = embedding
+
+     # update the embedding table
+     # note: avoid resize_token_embeddings() here, because it would also change the lm_head of the model
+     embedding_layer.weight.data = new_embedding_weights
+
+     # model.resize_token_embeddings(len(tokenizer))
+
+     return model, tokenizer
+
+
+ model_path = "path/to/Mistral-7B-Instruct-v0.1"
+ model = AutoModelForCausalLM.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model, tokenizer = load_tokens(model, tokenizer, token_embedding_path="/path/to/mistral.7b.instruct.added_token_embeddings.pt")
+ ```
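The effect of `load_tokens` can be illustrated on a toy model without downloading any checkpoint. The sketch below applies the same row-appending idea to a small `nn.Embedding` and confirms that a separate output head keeps its shape. `extend_embedding`, the layer sizes, and the random "virtual token" vectors are assumptions made for this illustration only.

```python
import torch
from torch import nn


def extend_embedding(embedding_layer, new_rows):
    """Append new_rows below the existing embedding matrix, in place.

    Illustrative helper (not from the SPRING code): it mirrors the
    manual update in load_tokens on a plain nn.Embedding.
    """
    old = embedding_layer.weight.data
    new_rows = new_rows.to(dtype=old.dtype, device=old.device)
    embedding_layer.weight.data = torch.cat([old, new_rows], dim=0)
    return embedding_layer


torch.manual_seed(0)
vocab_size, dim, num_virtual = 10, 4, 3
embed = nn.Embedding(vocab_size, dim)             # stands in for the input embeddings
lm_head = nn.Linear(dim, vocab_size, bias=False)  # stands in for the output head
original = embed.weight.data.clone()
virtual = torch.randn(num_virtual, dim)           # stands in for the loaded .pt tensor

extend_embedding(embed, virtual)

# Original rows are unchanged, the new rows hold the virtual tokens,
# and the output head still covers only the original vocabulary.
assert torch.equal(embed.weight.data[:vocab_size], original)
assert torch.equal(embed(torch.tensor([vocab_size + 1])), virtual[1:2])
assert lm_head.weight.shape == (vocab_size, dim)
```

In this toy setup the new ids can be looked up as inputs, while the output head has no logits for them, which is the behavior the note about `resize_token_embeddings()` aims to preserve.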
+
+ ### Add virtual tokens to the input
+ ```python
+ # using 50 tokens as an example
+ added_tokens = [f" [ref{i}]" for i in range(1, 51)]
+ added_tokens = "".join(added_tokens)
+ retrieved_results = "..."
+ question = "..."
+ text = [f"{retrieved_results}{added_tokens}Question: {question}\nAnswer:"]
+
+ ...
+
+ outputs = model.generate(...)
+ ```
+
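The prompt layout above can be wrapped in a small helper so that the number of virtual tokens becomes a parameter. This is a sketch: `build_prompt` and `num_virtual_tokens` are names introduced here for illustration, and the template simply mirrors the f-string shown above.

```python
def build_prompt(retrieved_results, question, num_virtual_tokens=50):
    """Place the virtual tokens between the retrieved text and the question.

    Hypothetical helper; it reproduces the template from the snippet above.
    """
    if num_virtual_tokens < 1:
        raise ValueError("num_virtual_tokens must be at least 1")
    # Tokens are named [ref1], [ref2], ... to match load_tokens().
    added_tokens = "".join(f" [ref{i}]" for i in range(1, num_virtual_tokens + 1))
    return f"{retrieved_results}{added_tokens}Question: {question}\nAnswer:"


# Placeholder strings, purely for demonstration.
prompt = build_prompt("Some retrieved passage.", "What is RAG?", num_virtual_tokens=2)
print(prompt)
```

The count should not exceed the number of rows in the loaded embeddings file, since `load_tokens` only adds that many `[ref…]` tokens to the tokenizer.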
+ ## Citation
+ Please cite our paper if it helps your research:
+ ```bibtex
+ @article{SPRING,
+   author       = {Yutao Zhu and
+                   Zhaoheng Huang and
+                   Zhicheng Dou and
+                   Ji{-}Rong Wen},
+   title        = {One Token Can Help! Learning Scalable and Pluggable Virtual Tokens for Retrieval-Augmented Large Language Models},
+   journal      = {CoRR},
+   volume       = {abs/2405.19670},
+   year         = {2024},
+   url          = {https://doi.org/10.48550/arXiv.2405.19670},
+   doi          = {10.48550/ARXIV.2405.19670},
+   eprinttype   = {arXiv},
+   eprint       = {2405.19670}
+ }
+ ```