phoebeklett commited on
Commit
84a3172
·
verified ·
1 Parent(s): 1b5ad64

Upload model code

Browse files
Files changed (2) hide show
  1. configuration.py +294 -0
  2. modeling.py +1297 -0
configuration.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This code has been adapted from Mosaic ML and Huggingface and inherits the above lisence.
16
+ # The original code can be found here:
17
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/configuration_mpt.py
18
+
19
+ """Extended Mind Mpt configuration"""
20
+ from typing import Optional, Union
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class ExtendedMptAttentionConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`ExtendedMptAttention`] class. It is used to instantiate
31
+ attention layers according to the specified arguments, defining the layers architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the MPT
33
+ [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b) architecture. Most of the arguments are kept for backward
34
+ compatibility with previous MPT models that are hosted on the Hub (previously with `trust_remote_code=True`).
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+ Args:
40
+ attn_type (`str`, *optional*, defaults to `"multihead_attention"`):
41
+ type of attention to use. Options: `"multihead_attention"`, `"multiquery_attention"`.
42
+ attn_pdrop (`float`, *optional*, defaults to 0.0):
43
+ The dropout probability for the attention layers.
44
+ attn_impl (`str`, *optional*, defaults to `"torch"`):
45
+ The attention implementation to use. One of `"torch"`, `"flash"`, or `"triton"`.
46
+ clip_qkv (`float`, *optional*):
47
+ If not `None`, clip the queries, keys, and values in the attention layer to this value.
48
+ softmax_scale (`float`, *optional*, defaults to `None`):
49
+ If not `None`, scale the softmax in the attention layer by this value. If `None`, will default to
50
+ `1/sqrt(hidden_size)`.
51
+ prefix_lm (`bool`, *optional*, defaults to `False`)):
52
+ Whether the model should operate as a Prefix LM. This requires passing an extra `prefix_mask` argument
53
+ which indicates which tokens belong to the prefix. Tokens in the prefix can attend to one another
54
+ bi-directionally. Tokens outside the prefix use causal attention.
55
+ qk_ln (`bool`, *optional*, defaults to `False`):
56
+ Whether to apply layer normalization to the queries and keys in the attention layer.
57
+ attn_uses_sequence_id (`bool`, *optional*, defaults to `False`)):
58
+ Whether to restrict attention to tokens that have the same token_type_ids. When the model is in `train`
59
+ mode, this requires passing an extra *token_type_ids* argument which indicates which sub-sequence each
60
+ token belongs to. Defaults to `False` meaning any provided *token_type_ids* will be ignored.
61
+ alibi (`bool`, *optional*, defaults to `True`):
62
+ Whether or not to use the alibi bias instead of positional embedding.
63
+ alibi_bias_max (`int`, *optional*, defaults to 8):
64
+ The maximum value of the alibi bias.
65
+
66
+ #### Memory Configuration ####
67
+ topk (`int`, *optional*, defaults to `10`):
68
+ Number of external memories for each query token to retrieve and attend to.
69
+ memory_type (`string`, *optional*, defaults to `manual`):
70
+ Whether to store external memories manually or in a vector database.
71
+ memory_device (`string`, *optional*, defaults to `cpu`):
72
+ Specify device to store memory.
73
+ mask_by_sim (`bool`, *optional*, defaults to `True`):
74
+ Whether or not to mask retrieved memories by similarity.
75
+ sim_threshold (`float`, *optional*, defaults to `0.25`):
76
+ Threshold for masking retrieved memories.
77
+ tokenizer_all_special_ids (`list`, *optional*, defaults to `[0, 50278]`):
78
+ Ids for special tokens to remove from memories.
79
+ remove_special_tokens (`bool`, *optional*, defaults to `True`):
80
+ Remove memories that correspond to tokenizer special ids.
81
+ #### Memory Configuration ####
82
+
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ attn_type="multihead_attention",
88
+ attn_pdrop=0,
89
+ attn_impl="torch",
90
+ clip_qkv=None,
91
+ softmax_scale=None,
92
+ prefix_lm=False,
93
+ qk_ln=False,
94
+ attn_uses_sequence_id=False,
95
+ alibi=True,
96
+ alibi_bias_max=8,
97
+ topk=10,
98
+ memory_type="manual",
99
+ memory_device="cpu",
100
+ mask_by_sim=True,
101
+ sim_threshold=0.25,
102
+ tokenizer_all_special_ids=[0, 50278],
103
+ remove_special_ids=False,
104
+ **kwargs,
105
+ ):
106
+ super().__init__(**kwargs)
107
+ self.attn_type = attn_type
108
+ self.attn_pdrop = attn_pdrop
109
+ self.attn_impl = attn_impl
110
+ self.clip_qkv = clip_qkv
111
+ self.softmax_scale = softmax_scale
112
+ self.prefix_lm = prefix_lm
113
+ self.attn_uses_sequence_id = attn_uses_sequence_id
114
+ self.alibi = alibi
115
+ self.qk_ln = qk_ln
116
+ self.alibi_bias_max = alibi_bias_max
117
+ self.topk = topk
118
+ self.memory_type = memory_type
119
+ self.memory_device = memory_device
120
+ self.mask_by_sim = mask_by_sim
121
+ self.sim_threshold = sim_threshold
122
+ self.tokenizer_all_special_ids = tokenizer_all_special_ids
123
+ self.remove_special_ids = remove_special_ids
124
+
125
+ if attn_type not in ["multihead_attention", "multiquery_attention"]:
126
+ raise ValueError(
127
+ f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
128
+ )
129
+
130
+ @classmethod
131
+ def from_pretrained(
132
+ cls, pretrained_model_name_or_path, **kwargs
133
+ ) -> "PretrainedConfig":
134
+ cls._set_token_in_kwargs(kwargs)
135
+
136
+ config_dict, kwargs = cls.get_config_dict(
137
+ pretrained_model_name_or_path, **kwargs
138
+ )
139
+
140
+ if config_dict.get("model_type") == "mpt":
141
+ config_dict = config_dict["attn_config"]
142
+
143
+ if (
144
+ "model_type" in config_dict
145
+ and hasattr(cls, "model_type")
146
+ and config_dict["model_type"] != cls.model_type
147
+ ):
148
+ logger.warning(
149
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
150
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
151
+ )
152
+
153
+ return cls.from_dict(config_dict, **kwargs)
154
+
155
+
156
+ class ExtendedMptConfig(PretrainedConfig):
157
+ """
158
+ This is the configuration class to store the configuration of a [`MptModel`]. It is used to instantiate a Mpt model
159
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
160
+ defaults will yield a similar configuration to the Mpt-7b architecture
161
+ [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b).
162
+
163
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
164
+ documentation from [`PretrainedConfig`] for more information.
165
+
166
+ Args:
167
+ d_model (`int`, *optional*, defaults to 2048):
168
+ Dimensionality of the embeddings and hidden states.
169
+ n_heads (`int`, *optional*, defaults to 16):
170
+ Number of attention heads for each attention layer in the Transformer encoder.
171
+ n_layers (`int`, *optional*, defaults to 24):
172
+ Number of hidden layers in the Transformer encoder.
173
+ expansion_ratio (`int`, *optional*, defaults to 4):
174
+ The ratio of the up/down scale in the MLP.
175
+ max_seq_len (`int`, *optional*, defaults to 2048):
176
+ The maximum sequence length of the model.
177
+ vocab_size (`int`, *optional*, defaults to 50368):
178
+ Vocabulary size of the Mpt model. Defines the maximum number of different tokens that can be represented by
179
+ the `inputs_ids` passed when calling [`MptModel`]. Check [this
180
+ discussion](https://huggingface.co/bigscience/mpt/discussions/120#633d28389addb8530b406c2a) on how the
181
+ `vocab_size` has been defined.
182
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
183
+ The dropout probability applied to the attention output before combining with residual.
184
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
185
+ The epsilon to use in the layer normalization layers.
186
+ emb_pdrop (`float`, *optional*, defaults to 0.1):
187
+ The dropout probability for the embedding layer.
188
+ learned_pos_emb (`bool`, *optional*, defaults to `False`):
189
+ Whether to use learned positional embeddings.
190
+ attn_config (`dict`, *optional*):
191
+ A dictionary used to configure the model's attention module.
192
+ init_device (`str`, *optional*):
193
+ The device to use for parameter initialization. Defined for backward compatibility
194
+ logit_scale (`float`, *optional*):
195
+ If not None, scale the logits by this value.
196
+ no_bias (`bool`, *optional*, defaults to `True`):
197
+ Whether to use bias in all linear layers.
198
+ verbose (`int`, *optional*, defaults to 0):
199
+ The verbosity level to use for logging. Used in the previous versions of MPT models for logging. This
200
+ argument is deprecated.
201
+ embedding_fraction (`float`, *optional*, defaults to 1.0):
202
+ The fraction to scale the gradients of the embedding layer by.
203
+ norm_type (`str`, *optional*, defaults to `"low_precision_layernorm"`):
204
+ Type of layer norm to use. All MPT models uses the same layer norm implementation. Defined for backward
205
+ compatibility.
206
+ use_cache (`bool`, *optional*, defaults to `True`):
207
+ Whether or not the model should return the last key/values attentions (not used by all models).
208
+ initializer_range (`float`, *optional*, defaults to 0.02):
209
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
210
+
211
+ #### Memory Configuration ####
212
+ use_external_mind (`bool`, *optional*, defaults to `True`):
213
+ Whether to attend to external memories.
214
+ use_external_mind_by_layer (`List[bool]`, *optional*, defaults to List[`True`, ..., `True`]):
215
+ Whether to attend to external memories, on each decoder layer.
216
+ #### Memory Configuration ####
217
+
218
+ Example:
219
+
220
+ ```python
221
+ >>> from transformers import MptConfig, MptModel
222
+
223
+ >>> # Initializing a Mpt configuration
224
+ >>> configuration = MptConfig()
225
+
226
+ >>> # Initializing a model (with random weights) from the configuration
227
+ >>> model = MptModel(configuration)
228
+
229
+ >>> # Accessing the model configuration
230
+ >>> configuration = model.config
231
+ ```
232
+ """
233
+
234
+ model_type = "extended-mpt"
235
+ attribute_map = {
236
+ "num_attention_heads": "n_heads",
237
+ "hidden_size": "d_model",
238
+ "num_hidden_layers": "n_layers",
239
+ }
240
+
241
+ def __init__(
242
+ self,
243
+ d_model: int = 4096,
244
+ n_heads: int = 32,
245
+ n_layers: int = 32,
246
+ expansion_ratio: int = 4,
247
+ max_seq_len_inference: int = 2048,
248
+ max_seq_len_train: int = 2048,
249
+ vocab_size: int = 50432,
250
+ resid_pdrop: float = 0.0,
251
+ layer_norm_epsilon: float = 1e-5,
252
+ emb_pdrop: float = 0.0,
253
+ learned_pos_emb: bool = True,
254
+ attn_config: ExtendedMptAttentionConfig = None,
255
+ init_device: str = "cpu",
256
+ logit_scale: Optional[Union[float, str]] = None,
257
+ no_bias: bool = True,
258
+ verbose: int = 0,
259
+ embedding_fraction: float = 1.0,
260
+ norm_type: str = "low_precision_layernorm",
261
+ use_cache: bool = False,
262
+ initializer_range=0.02,
263
+ use_external_mind: bool = True,
264
+ use_external_mind_by_layer: list[bool] = [True for _ in range(32)],
265
+ **kwargs,
266
+ ):
267
+ if attn_config is None:
268
+ self.attn_config = ExtendedMptAttentionConfig()
269
+ elif not isinstance(attn_config, ExtendedMptAttentionConfig):
270
+ self.attn_config = ExtendedMptAttentionConfig(**attn_config)
271
+ else:
272
+ self.attn_config = attn_config
273
+ self.d_model = d_model
274
+ self.n_heads = n_heads
275
+ self.n_layers = n_layers
276
+ self.expansion_ratio = expansion_ratio
277
+ self.max_seq_len = max_seq_len_inference
278
+ self.max_seq_len_train = max_seq_len_train
279
+ self.vocab_size = vocab_size
280
+ self.resid_pdrop = resid_pdrop
281
+ self.emb_pdrop = emb_pdrop
282
+ self.learned_pos_emb = learned_pos_emb
283
+ self.init_device = init_device
284
+ self.logit_scale = logit_scale
285
+ self.no_bias = no_bias
286
+ self.verbose = verbose
287
+ self.embedding_fraction = embedding_fraction
288
+ self.norm_type = norm_type
289
+ self.layer_norm_epsilon = layer_norm_epsilon
290
+ self.use_cache = use_cache
291
+ self.initializer_range = initializer_range
292
+ self.use_external_mind = use_external_mind
293
+ self.use_external_mind_by_layer = use_external_mind_by_layer
294
+ super().__init__(**kwargs)
modeling.py ADDED
@@ -0,0 +1,1297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This code has been adapted from Mosaic ML and Huggingface and inherits the above lisence.
16
+ # The original code can be found here:
17
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
18
+ # We annotate the edited code below with 'EM' comments to indicate where we have made changes.
19
+ """PyTorch MPT model."""
20
+
21
+ import math
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import faiss
25
+ import numpy as np
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from einops import rearrange
29
+ from torch import nn
30
+ from torch.linalg import vector_norm
31
+ from torch.nn import CrossEntropyLoss, LayerNorm
32
+ from torch.nn import functional as F
33
+ from transformers.file_utils import (
34
+ add_code_sample_docstrings,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPastAndCrossAttentions,
40
+ CausalLMOutputWithCrossAttentions,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel
43
+ from transformers.utils import logging
44
+
45
+ from .configuration import ExtendedMptConfig
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CHECKPOINT_FOR_DOC = "mosaicml/mpt-7b"
50
+ _CONFIG_FOR_DOC = "MptConfig"
51
+
52
+
53
+ # Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
54
+ def _make_causal_mask(
55
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
56
+ ) -> torch.BoolTensor:
57
+ """
58
+ Make causal mask used for self-attention.
59
+ """
60
+ batch_size, target_length = input_ids_shape
61
+ mask = torch.empty(
62
+ (target_length, target_length + past_key_values_length),
63
+ dtype=torch.bool,
64
+ device=device,
65
+ )
66
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
67
+ seq_ids = torch.arange(target_length, device=device)
68
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
69
+
70
+ if past_key_values_length > 0:
71
+ mask[:, :past_key_values_length] = False
72
+
73
+ expanded_mask = mask[None, None, :, :].expand(
74
+ batch_size, 1, target_length, target_length + past_key_values_length
75
+ )
76
+ return expanded_mask
77
+
78
+
79
+ # Copied from transformers.models.bloom.modeling_bloom._expand_mask
80
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
81
+ """
82
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
83
+ """
84
+ batch_size, src_length = mask.shape
85
+ tgt_length = tgt_length if tgt_length is not None else src_length
86
+
87
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
88
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
89
+
90
+
91
+ def build_mpt_alibi_tensor(
92
+ num_heads,
93
+ sequence_length,
94
+ sequence_length_with_past,
95
+ alibi_bias_max=8,
96
+ device=None,
97
+ for_ae=False,
98
+ topk=None,
99
+ ):
100
+ r"""
101
+ Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
102
+ relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
103
+ the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
104
+ https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
105
+ """
106
+ if not for_ae:
107
+ alibi = torch.arange(
108
+ 1 - sequence_length, 1, dtype=torch.int32, device=device
109
+ ).view(1, 1, 1, sequence_length)
110
+ else: # EM: All memory tokens get same bias
111
+ alibi = (
112
+ torch.tensor(-sequence_length_with_past, dtype=torch.int32, device=device)
113
+ .repeat(sequence_length * topk)
114
+ .view(1, 1, 1, sequence_length * topk)
115
+ )
116
+ num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
117
+
118
+ base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device)
119
+ base = base * (alibi_bias_max / num_heads_power_of_2)
120
+
121
+ slopes = 1.0 / torch.pow(2, base)
122
+ slopes = slopes.view(1, num_heads, 1, 1)
123
+
124
+ if num_heads_power_of_2 != num_heads:
125
+ slopes = torch.concat([slopes[1::2], slopes[::2]])[:num_heads]
126
+
127
+ alibi = alibi * slopes
128
+ return alibi.squeeze(0)
129
+
130
+
131
+ class ExtendedMptAttention(nn.Module):
132
+ """Multi-head self attention.
133
+ Using torch or triton attention implemetation enables user to also use additive bias.
134
+ """
135
+
136
+ def __init__(self, config: ExtendedMptConfig):
137
+ super().__init__()
138
+ self.hidden_size = config.hidden_size
139
+ self.n_heads = config.n_heads
140
+ self.n_layers = config.n_layers
141
+ self.head_dim = self.hidden_size // self.n_heads
142
+ self.softmax_scale = config.attn_config.softmax_scale
143
+ if self.softmax_scale is None:
144
+ self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
145
+
146
+ self.attn_dropout_p = config.attn_config.attn_pdrop
147
+ self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
148
+ self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
149
+
150
+ def forward(
151
+ self,
152
+ hidden_states: torch.Tensor,
153
+ position_bias: torch.Tensor,
154
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
155
+ attention_mask: Optional[torch.Tensor] = None,
156
+ long_range_past_key_value=None,
157
+ topk=None,
158
+ faiss_indexes=None,
159
+ mask_by_sim=None,
160
+ sim_threshold=None,
161
+ position_bias_ae=None,
162
+ current_layer=None,
163
+ output_retrieved_memory_idx=False,
164
+ ):
165
+ batch_size, seq_length = hidden_states.shape[:2]
166
+
167
+ mixed_qkv = self.Wqkv(hidden_states)
168
+ query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
169
+ query_states = query_states.reshape(
170
+ batch_size, seq_length, self.n_heads, self.head_dim
171
+ ).transpose(1, 2)
172
+ key_states = key_states.reshape(
173
+ batch_size, seq_length, self.n_heads, self.head_dim
174
+ ).transpose(1, 2)
175
+ value_states = value_states.reshape(
176
+ batch_size, seq_length, self.n_heads, self.head_dim
177
+ ).transpose(1, 2)
178
+
179
+ if past_key_value is not None:
180
+ if len(past_key_value) != 0:
181
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
182
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
183
+ past_key_value = (key_states, value_states)
184
+ bsz, nh, s_q, d = query_states.shape
185
+
186
+ attention_scores = (
187
+ torch.matmul(query_states, key_states.transpose(-1, -2))
188
+ * self.softmax_scale
189
+ )
190
+ key_length = key_states.shape[-2]
191
+ query_length = (
192
+ seq_length
193
+ if past_key_value is None
194
+ else seq_length + past_key_value[0].shape[2]
195
+ )
196
+ if position_bias is not None:
197
+ if len(position_bias.shape) != 3:
198
+ raise ValueError(
199
+ f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}"
200
+ )
201
+
202
+ position_bias_query_index = max(0, position_bias.size(1) - query_length)
203
+ position_bias_key_index = max(0, position_bias.size(2) - key_length)
204
+
205
+ position_bias = position_bias[
206
+ :, position_bias_query_index:, position_bias_key_index:
207
+ ]
208
+
209
+ attention_scores = attention_scores + position_bias
210
+
211
+ # EM: Retrieve memories from cache or faiss indexes
212
+ if long_range_past_key_value is not None or faiss_indexes is not None:
213
+ if long_range_past_key_value is not None: # Manual store
214
+ k_cache, v_cache = long_range_past_key_value
215
+ s_cache = k_cache.size(-2)
216
+
217
+ k_cache = k_cache.to(key_states.device)
218
+ v_cache = v_cache.to(key_states.device)
219
+
220
+ # Normalize query and key vectors
221
+ q_n = query_states / vector_norm(
222
+ query_states, ord=2, dim=-1, keepdim=True
223
+ )
224
+ k_n = k_cache / vector_norm(k_cache, ord=2, dim=-1, keepdim=True)
225
+ sim = q_n.matmul(k_n.transpose(-1, -2))
226
+ if s_cache < topk: # number of tokens in cache < topk
227
+ topk = s_cache
228
+ val, idx = torch.topk(sim, k=topk, dim=-1) # Retrieve topk memories
229
+
230
+ reshaped_idx = idx.reshape(bsz, nh, s_q * topk)
231
+ selected_k = k_cache.gather(
232
+ dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, d)
233
+ )
234
+ selected_v = v_cache.gather(
235
+ dim=-2, index=reshaped_idx.unsqueeze(-1).expand(-1, -1, -1, d)
236
+ )
237
+
238
+ elif faiss_indexes is not None: # FAISS indexes
239
+ kn_index, kv_index = faiss_indexes
240
+ q_n = query_states / vector_norm(
241
+ query_states, ord=2, dim=-1, keepdim=True
242
+ )
243
+ # One-hot encoding for layer, head to only retrieve memories from the same layer, head
244
+ one_hot_encodings = (
245
+ F.one_hot(
246
+ torch.arange(0, nh * self.n_layers, device=query_states.device)
247
+ )
248
+ * 10
249
+ )
250
+ q_n = torch.concat(
251
+ [
252
+ rearrange(q_n, "b h s d -> b (h s) d", h=nh),
253
+ one_hot_encodings[nh * current_layer : nh * (current_layer + 1)]
254
+ .unsqueeze(0)
255
+ .repeat_interleave(repeats=query_states.size(-2), dim=-2),
256
+ ],
257
+ dim=-1,
258
+ ).squeeze()
259
+
260
+ if kn_index.ntotal / (nh * self.n_layers) < topk:
261
+ topk = int(kn_index.ntotal / (nh * self.n_layers))
262
+
263
+ val, idx = kn_index.search(q_n.to("cpu").detach().numpy(), k=topk)
264
+ val = torch.tensor(val - 100).reshape(bsz, nh, s_q, topk) #Similarity includes scale factor from one-hot encoding
265
+ reshaped_idx = torch.tensor(
266
+ idx % (kn_index.ntotal / (nh * self.n_layers))
267
+ ).reshape(bsz, nh, s_q * topk)
268
+
269
+ # Retrieve tensors
270
+ selected_k = rearrange(
271
+ torch.tensor(kv_index.reconstruct_batch(idx.flatten()))[:, :d],
272
+ "(h s) d -> 1 h s d",
273
+ h=nh,
274
+ ).to(query_states.device)
275
+ selected_v = rearrange(
276
+ torch.tensor(kv_index.reconstruct_batch(idx.flatten()))[:, d:],
277
+ "(h s) d -> 1 h s d",
278
+ h=nh,
279
+ ).to(query_states.device)
280
+
281
+ selected_key_length = selected_k.size(-2)
282
+ key_length += selected_key_length
283
+ attention_scores_cache = (
284
+ query_states.matmul(selected_k.transpose(-1, -2)) * self.softmax_scale
285
+ )
286
+ # EM: Mask by similarity
287
+ if mask_by_sim:
288
+ sim_mask = (
289
+ rearrange(~(val > sim_threshold).bool(), "b h s i -> b h (s i)")
290
+ .unsqueeze(-2)
291
+ .expand(-1, -1, s_q, -1)
292
+ ).to(query_states.device)
293
+
294
+ attention_scores_cache = attention_scores_cache.masked_fill(
295
+ sim_mask, torch.finfo(query_states.dtype).min
296
+ )
297
+
298
+ # EM: Add position bias to cache
299
+ if position_bias_ae is not None:
300
+ if len(position_bias_ae.shape) != 3:
301
+ raise ValueError(
302
+ f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias_ae.shape)}"
303
+ )
304
+
305
+ position_bias_query_index = max(
306
+ 0, position_bias_ae.size(1) - query_length
307
+ )
308
+ position_bias_key_index = max(
309
+ 0, position_bias_ae.size(2) - selected_key_length
310
+ )
311
+
312
+ position_bias_ae = position_bias_ae[
313
+ :, position_bias_query_index:, position_bias_key_index:
314
+ ]
315
+
316
+ attention_scores_cache = attention_scores_cache + position_bias_ae
317
+
318
+ # EM: Concatenate cache and current attention weights, values
319
+ attention_scores = torch.cat(
320
+ [attention_scores_cache, attention_scores], dim=-1
321
+ ) # Concat attention scores, values
322
+ value_states = torch.cat([selected_v, value_states], dim=-2)
323
+
324
+ # EM: Create mask for external memories, queries only attend to their own memories
325
+ def _create_external_memories_mask(k, s_q, device):
326
+ mask = torch.zeros(s_q, s_q * k, device=device, dtype=torch.bool)
327
+ for i in range(s_q):
328
+ mask[i, i * k : (i + 1) * k] = 1
329
+ return ~mask
330
+
331
+ if attention_mask is not None:
332
+ # EM: Concatenate attention mask with external memories mask
333
+ if long_range_past_key_value is not None or faiss_indexes is not None:
334
+ mask = _create_external_memories_mask(
335
+ k=topk, s_q=s_q, device=attention_scores.device
336
+ )
337
+ attention_mask = attention_mask.squeeze(dim=0).squeeze(dim=0)
338
+ attention_mask = torch.cat([mask, attention_mask], dim=1)
339
+ attention_scores = attention_scores.masked_fill(
340
+ attention_mask, torch.finfo(query_states.dtype).min
341
+ )
342
+
343
+ # (batch_size, n_heads, seq_length, key_length)
344
+ attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(
345
+ value_states.dtype
346
+ )
347
+ attn_weights = nn.functional.dropout(
348
+ attn_weights, p=self.attn_dropout_p, training=self.training
349
+ )
350
+
351
+ context_states = torch.matmul(attn_weights, value_states)
352
+ context_states = (
353
+ context_states.permute(0, 2, 1, 3)
354
+ .contiguous()
355
+ .view(batch_size, seq_length, -1)
356
+ )
357
+ attn_output = self.out_proj(context_states)
358
+
359
+ if not output_retrieved_memory_idx:
360
+ reshaped_idx = None
361
+
362
+ return attn_output, attn_weights, past_key_value, reshaped_idx
363
+
364
+
365
+ class MptMLP(nn.Module):
366
+ def __init__(self, config: ExtendedMptConfig):
367
+ super().__init__()
368
+ hidden_size = config.hidden_size
369
+
370
+ self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
371
+ self.act = nn.GELU(approximate="none")
372
+ self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False)
373
+ self.hidden_dropout = config.attn_config.attn_pdrop
374
+
375
+ def forward(
376
+ self, hidden_states: torch.Tensor, residual: torch.Tensor
377
+ ) -> torch.Tensor:
378
+ hidden_states = self.act(self.up_proj(hidden_states))
379
+
380
+ intermediate_output = self.down_proj(hidden_states)
381
+
382
+ output = F.dropout(
383
+ intermediate_output, p=self.hidden_dropout, training=self.training
384
+ )
385
+ output = output + residual
386
+
387
+ return output
388
+
389
+
390
+ class MptBlock(nn.Module):
391
+ """MPTBlock"""
392
+
393
+ def __init__(self, config: ExtendedMptConfig):
394
+ super().__init__()
395
+ hidden_size = config.hidden_size
396
+
397
+ self.norm_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
398
+ # backward compatibility with weights on the Hub
399
+ self.norm_1.bias = None
400
+
401
+ self.num_heads = config.n_heads
402
+ self.attn = ExtendedMptAttention(config)
403
+
404
+ self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
405
+ # backward compatibility with weights on the Hub
406
+ self.norm_2.bias = None
407
+
408
+ self.ffn = MptMLP(config)
409
+
410
+ self.dropout_rate = config.attn_config.attn_pdrop
411
+ self.resid_attn_dropout = nn.Dropout(self.dropout_rate)
412
+
413
+ def forward(
414
+ self,
415
+ hidden_states: torch.Tensor,
416
+ position_bias: torch.Tensor,
417
+ attention_mask: torch.Tensor,
418
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
419
+ use_cache: bool = False,
420
+ output_attentions: bool = False,
421
+ output_retrieved_memory_idx: bool = False,
422
+ topk: int = None,
423
+ long_range_past_key_value: Optional[Tuple[torch.Tensor]] = None,
424
+ faiss_indexes: Tuple = None,
425
+ position_bias_ae=None,
426
+ current_layer: int = None,
427
+ mask_by_sim: bool = False,
428
+ sim_threshold: float = None,
429
+ ):
430
+ # hidden_states: [batch_size, seq_length, hidden_size]
431
+ # Layer norm at the beginning of the transformer layer.
432
+ layernorm_output = self.norm_1(hidden_states)
433
+
434
+ residual = hidden_states
435
+
436
+ # Self attention.
437
+ attn_outputs, attn_weights, past_key_value, reshaped_idx = self.attn(
438
+ layernorm_output,
439
+ position_bias=position_bias,
440
+ attention_mask=attention_mask,
441
+ past_key_value=layer_past,
442
+ long_range_past_key_value=long_range_past_key_value,
443
+ topk=topk,
444
+ faiss_indexes=faiss_indexes,
445
+ position_bias_ae=position_bias_ae,
446
+ current_layer=current_layer,
447
+ mask_by_sim=mask_by_sim,
448
+ sim_threshold=sim_threshold,
449
+ output_retrieved_memory_idx=output_retrieved_memory_idx,
450
+ )
451
+
452
+ hidden_states = self.resid_attn_dropout(attn_outputs) + residual
453
+
454
+ layernorm_output = self.norm_2(hidden_states)
455
+
456
+ # Get residual
457
+ residual = hidden_states
458
+
459
+ # MLP.
460
+ output = self.ffn(layernorm_output, residual)
461
+ outputs = (output,)
462
+
463
+ if use_cache:
464
+ outputs += (past_key_value,)
465
+
466
+ if output_attentions:
467
+ outputs += (attn_weights,)
468
+ if output_retrieved_memory_idx:
469
+ outputs += (reshaped_idx,)
470
+
471
+ return outputs # hidden_states, present, attentions
472
+
473
+
474
+ class MptPreTrainedModel(PreTrainedModel):
475
+ """MPT Pretrained Model"""
476
+
477
+ config_class = ExtendedMptConfig
478
+ base_model_prefix = "transformer"
479
+ supports_gradient_checkpointing = True
480
+ _no_split_modules = ["MptBlock"]
481
+ _keys_to_ignore_on_load_missing = [r"lm_head.*."]
482
+
483
+ def __init__(self, *inputs, **kwargs):
484
+ super().__init__(*inputs, **kwargs)
485
+
486
+ def _init_weights(self, module: nn.Module):
487
+ """Initialize the weights."""
488
+ if isinstance(module, nn.Linear):
489
+ # Slightly different from the TF version which uses truncated_normal for initialization
490
+ # cf https://github.com/pytorch/pytorch/pull/5617
491
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
492
+ if module.bias is not None:
493
+ module.bias.data.zero_()
494
+ elif isinstance(module, nn.Embedding):
495
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
496
+ if module.padding_idx is not None:
497
+ module.weight.data[module.padding_idx].zero_()
498
+ elif isinstance(module, LayerNorm):
499
+ if module.bias is not None:
500
+ module.bias.data.zero_()
501
+ module.weight.data.fill_(1.0)
502
+
503
+ def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
504
+ if isinstance(module, ExtendedMptConfig):
505
+ module.gradient_checkpointing = value
506
+
507
+ @staticmethod
508
+ def _convert_to_mpt_cache(
509
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
510
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
511
+ """
512
+ Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
513
+ """
514
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
515
+ batch_size_times_num_heads = batch_size * num_heads
516
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
517
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
518
+ return tuple(
519
+ (
520
+ layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length),
521
+ layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim),
522
+ )
523
+ for layer_past in past_key_value
524
+ )
525
+
526
+
527
+ MPT_START_DOCSTRING = r"""
528
+
529
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
530
+ library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
531
+
532
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
533
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
534
+ and behavior.
535
+
536
+ Parameters:
537
+ config ([`ExtendedMptConfig`]): Model configuration class with all the parameters of the model.
538
+ Initializing with a config file does not load the weights associated with the model, only the
539
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
540
+ """
541
+
542
+ MPT_INPUTS_DOCSTRING = r"""
543
+ Args:
544
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
545
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
546
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
547
+
548
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
549
+ `input_ids`.
550
+
551
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
552
+ [`PreTrainedTokenizer.__call__`] for details.
553
+
554
+ [What are input IDs?](../glossary#input-ids)
555
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
556
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
557
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
558
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
559
+
560
+ Each element of `past_key_values` is a tuple (past_key, past_value):
561
+ - past_key: [batch_size * num_heads, head_dim, kv_length]
562
+ - past_value: [batch_size * num_heads, kv_length, head_dim]
563
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
564
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
565
+
566
+ - 1 for tokens that are **not masked**,
567
+ - 0 for tokens that are **masked**.
568
+
569
+ [What are attention masks?](../glossary#attention-mask)
570
+
571
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
572
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
573
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
574
+ model's internal embedding lookup matrix.
575
+
576
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
577
+ `past_key_values`).
578
+ use_cache (`bool`, *optional*):
579
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
580
+ `past_key_values`).
581
+ output_attentions (`bool`, *optional*):
582
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
583
+ tensors for more detail.
584
+ output_hidden_states (`bool`, *optional*):
585
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
586
+ more detail.
587
+ return_dict (`bool`, *optional*):
588
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
589
+ use_external_mind (`bool`, *optional*, defaults to `True`):
590
+ Whether to attend to external memories.
591
+ long_range_past_key_values (`List[Tuple[torch.FloatTensor]]`, *optional*, defaults to None):
592
+ Manual store for memories.
593
+ faiss_indexes (`Tuple[faiss.swigfaiss_avx2.IndexFlatIP]`, *optional*, defaults to None):
594
+ Vector store for memories.
595
+ topk (`int`, *optional*, defaults to `10`):
596
+ Number of external memories for each query token to retrieve and attend to.
597
+ """
598
+
599
+
600
+ @add_start_docstrings(
601
+ "The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.",
602
+ MPT_START_DOCSTRING,
603
+ )
604
+ class ExtendedMptModel(MptPreTrainedModel):
605
+ """Extended MPT Model"""
606
+
607
+ def __init__(self, config: ExtendedMptConfig):
608
+ super().__init__(config)
609
+
610
+ self.hidden_size = config.hidden_size
611
+ self.num_heads = config.n_heads
612
+
613
+ # Embedding + LN Embedding
614
+ self.wte = nn.Embedding(config.vocab_size, self.hidden_size)
615
+
616
+ # Transformer blocks
617
+ self.blocks = nn.ModuleList([MptBlock(config) for _ in range(config.n_layers)])
618
+
619
+ # Final Layer Norm
620
+ self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
621
+ # backward compatibility with weights on the Hub
622
+ self.norm_f.bias = None
623
+
624
+ self.gradient_checkpointing = False
625
+
626
+ # Initialize weights and apply final processing
627
+ self.post_init()
628
+
629
+ self.mask_by_sim = config.attn_config.mask_by_sim
630
+ self.sim_threshold = config.attn_config.sim_threshold
631
+ self.topk = config.attn_config.topk
632
+ self.use_external_mind = config.use_external_mind
633
+ self.use_external_mind_by_layer = config.use_external_mind_by_layer
634
+
635
+ def get_input_embeddings(self):
636
+ return self.wte
637
+
638
+ def build_mpt_alibi_tensor(
639
+ self,
640
+ num_heads,
641
+ sequence_length,
642
+ sequence_length_with_past,
643
+ alibi_bias_max=8,
644
+ device=None,
645
+ for_ae=None,
646
+ topk=None,
647
+ ):
648
+ return build_mpt_alibi_tensor(
649
+ num_heads,
650
+ sequence_length,
651
+ sequence_length_with_past,
652
+ alibi_bias_max,
653
+ device,
654
+ for_ae=for_ae,
655
+ topk=topk,
656
+ )
657
+
658
+ def _prepare_attn_mask(
659
+ self,
660
+ attention_mask: torch.Tensor,
661
+ input_shape: Tuple[int, int],
662
+ past_key_values_length: int,
663
+ ) -> torch.BoolTensor:
664
+ # create causal mask
665
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
666
+ if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
667
+ raise ValueError(
668
+ "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
669
+ f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
670
+ f" {past_key_values_length}."
671
+ )
672
+ combined_attention_mask = None
673
+ device = attention_mask.device
674
+ _, src_length = input_shape
675
+
676
+ if src_length > 1:
677
+ combined_attention_mask = _make_causal_mask(
678
+ input_shape,
679
+ device=device,
680
+ past_key_values_length=past_key_values_length,
681
+ )
682
+
683
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
684
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
685
+ combined_attention_mask = (
686
+ expanded_attn_mask
687
+ if combined_attention_mask is None
688
+ else expanded_attn_mask | combined_attention_mask
689
+ )
690
+
691
+ return combined_attention_mask
692
+
693
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
694
+ self.wte = new_embeddings
695
+
696
+ @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
697
+ @add_code_sample_docstrings(
698
+ checkpoint=_CHECKPOINT_FOR_DOC,
699
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
700
+ config_class=_CONFIG_FOR_DOC,
701
+ )
702
+ def forward(
703
+ self,
704
+ input_ids: Optional[torch.LongTensor] = None,
705
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
706
+ attention_mask: Optional[torch.Tensor] = None,
707
+ inputs_embeds: Optional[torch.LongTensor] = None,
708
+ use_cache: Optional[bool] = None,
709
+ output_attentions: Optional[bool] = None,
710
+ output_hidden_states: Optional[bool] = None,
711
+ output_retrieved_memory_idx: Optional[bool] = None,
712
+ return_dict: Optional[bool] = None,
713
+ use_external_mind: Optional[bool] = None,
714
+ long_range_past_key_values: Optional[list[Tuple[torch.FloatTensor]]] = None,
715
+ faiss_indexes: Tuple = None,
716
+ topk: int = None,
717
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
718
+ output_attentions = (
719
+ output_attentions
720
+ if output_attentions is not None
721
+ else self.config.output_attentions
722
+ )
723
+ output_retrieved_memory_idx = (
724
+ output_retrieved_memory_idx
725
+ if output_retrieved_memory_idx is not None
726
+ else False
727
+ )
728
+ output_hidden_states = (
729
+ output_hidden_states
730
+ if output_hidden_states is not None
731
+ else self.config.output_hidden_states
732
+ )
733
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
734
+ return_dict = (
735
+ return_dict if return_dict is not None else self.config.use_return_dict
736
+ )
737
+ use_external_mind = (
738
+ use_external_mind
739
+ if use_external_mind is not None
740
+ else self.use_external_mind
741
+ )
742
+ topk = topk if topk is not None else self.topk
743
+
744
+ if input_ids is not None and inputs_embeds is not None:
745
+ raise ValueError(
746
+ "You cannot specify both input_ids and inputs_embeds at the same time"
747
+ )
748
+ elif input_ids is not None:
749
+ batch_size, seq_length = input_ids.shape
750
+ elif inputs_embeds is not None:
751
+ batch_size, seq_length, _ = inputs_embeds.shape
752
+ else:
753
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
754
+
755
+ if past_key_values is None:
756
+ past_key_values = tuple([None] * len(self.blocks))
757
+
758
+ if inputs_embeds is None:
759
+ inputs_embeds = self.wte(input_ids)
760
+
761
+ hidden_states = inputs_embeds
762
+
763
+ presents = () if use_cache else None
764
+ all_self_attentions = () if output_attentions else None
765
+ all_hidden_states = () if output_hidden_states else None
766
+ all_idx = () if output_retrieved_memory_idx else None
767
+
768
+ if self.gradient_checkpointing and self.training:
769
+ if use_cache:
770
+ logger.warning_once(
771
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
772
+ )
773
+ use_cache = False
774
+
775
+ # Compute alibi tensor: check build_alibi_tensor documentation
776
+ seq_length_with_past = seq_length
777
+ past_key_values_length = 0
778
+ if past_key_values[0] is not None:
779
+ past_key_values_length = past_key_values[0][0].shape[2]
780
+ seq_length_with_past = seq_length_with_past + past_key_values_length
781
+ if attention_mask is None:
782
+ attention_mask = torch.ones(
783
+ (batch_size, seq_length_with_past), device=hidden_states.device
784
+ )
785
+ else:
786
+ attention_mask = attention_mask.to(hidden_states.device)
787
+
788
+ alibi = self.build_mpt_alibi_tensor(
789
+ self.num_heads,
790
+ self.config.max_seq_len,
791
+ seq_length_with_past,
792
+ device=hidden_states.device,
793
+ )
794
+ # EM: Alibi tensor for retrieved kvs
795
+ alibi_ae = self.build_mpt_alibi_tensor(
796
+ self.num_heads,
797
+ seq_length,
798
+ seq_length_with_past,
799
+ device=hidden_states.device,
800
+ for_ae=True,
801
+ topk=topk,
802
+ )
803
+
804
+ causal_mask = self._prepare_attn_mask(
805
+ attention_mask,
806
+ input_shape=(batch_size, seq_length),
807
+ past_key_values_length=past_key_values_length,
808
+ )
809
+
810
+ for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)):
811
+ if output_hidden_states:
812
+ all_hidden_states = all_hidden_states + (hidden_states,)
813
+
814
+ long_range_past_key_value = (
815
+ long_range_past_key_values[i]
816
+ if (
817
+ long_range_past_key_values is not None
818
+ and self.use_external_mind_by_layer[i]
819
+ and use_external_mind is True
820
+ )
821
+ else None
822
+ )
823
+ if long_range_past_key_value is not None and faiss_indexes is not None:
824
+ raise NotImplementedError(
825
+ """Using faiss and passing key value pairs
826
+ manually are mutually exclusive right now."""
827
+ )
828
+ if self.gradient_checkpointing and self.training:
829
+
830
+ def create_custom_forward(module):
831
+ def custom_forward(*inputs):
832
+ # None for past_key_value
833
+ return module(
834
+ *inputs,
835
+ use_cache=use_cache,
836
+ output_attentions=output_attentions,
837
+ )
838
+
839
+ return custom_forward
840
+
841
+ outputs = torch.utils.checkpoint.checkpoint(
842
+ create_custom_forward(block),
843
+ hidden_states,
844
+ alibi,
845
+ causal_mask,
846
+ layer_past,
847
+ )
848
+ else:
849
+ outputs = block(
850
+ hidden_states,
851
+ layer_past=layer_past,
852
+ attention_mask=causal_mask,
853
+ use_cache=use_cache,
854
+ output_attentions=output_attentions,
855
+ output_retrieved_memory_idx=output_retrieved_memory_idx,
856
+ position_bias=alibi,
857
+ position_bias_ae=alibi_ae,
858
+ topk=topk,
859
+ long_range_past_key_value=long_range_past_key_value,
860
+ faiss_indexes=faiss_indexes,
861
+ mask_by_sim=self.mask_by_sim,
862
+ sim_threshold=self.sim_threshold,
863
+ current_layer=i,
864
+ )
865
+
866
+ hidden_states = outputs[0]
867
+ if use_cache is True:
868
+ presents = presents + (outputs[1],)
869
+
870
+ if output_attentions:
871
+ all_self_attentions = all_self_attentions + (
872
+ outputs[2 if use_cache else 1],
873
+ )
874
+ if output_retrieved_memory_idx:
875
+ idx = (
876
+ 3
877
+ if (use_cache & output_attentions)
878
+ else 2
879
+ if (use_cache or output_attentions)
880
+ else 1
881
+ )
882
+ all_idx = all_idx + (outputs[idx],)
883
+
884
+ # Add last hidden state
885
+ hidden_states = self.norm_f(hidden_states)
886
+
887
+ if output_hidden_states:
888
+ all_hidden_states = all_hidden_states + (hidden_states,)
889
+
890
+ if not return_dict:
891
+ return tuple(
892
+ v
893
+ for v in [
894
+ hidden_states,
895
+ presents,
896
+ all_hidden_states,
897
+ all_self_attentions,
898
+ all_idx,
899
+ ]
900
+ if v is not None
901
+ )
902
+
903
+ return BaseModelOutputWithPastAndCrossAttentions(
904
+ last_hidden_state=hidden_states,
905
+ past_key_values=presents,
906
+ hidden_states=all_hidden_states,
907
+ attentions=(all_self_attentions, all_idx), # EM: Return idx of retrieved memories
908
+ )
909
+
910
+
911
+ @add_start_docstrings(
912
+ """
913
+ The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
914
+ embeddings).
915
+ """,
916
+ MPT_START_DOCSTRING,
917
+ )
918
+ class ExtendedMptForCausalLM(MptPreTrainedModel):
919
+ """Extended MPT for Causal LM."""
920
+
921
+ _tied_weights_keys = ["lm_head.weight"]
922
+
923
+ def __init__(self, config: ExtendedMptConfig, external_memories=None):
924
+ super().__init__(config)
925
+ self.transformer: ExtendedMptModel = ExtendedMptModel(config)
926
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
927
+
928
+ self.use_external_mind = config.use_external_mind
929
+ self.memory_type = config.attn_config.memory_type
930
+ self.memory_ids = None
931
+ self.memories = None
932
+ self.memory_device = config.attn_config.memory_device
933
+ self.remove_special_ids = config.attn_config.remove_special_ids
934
+ self.tokenizer_all_special_ids = config.attn_config.tokenizer_all_special_ids
935
+
936
+ # EM: Memory token ids
937
+ if external_memories is not None:
938
+ self.memory_ids = external_memories
939
+ # Initialize weights and apply final processing
940
+ self.post_init()
941
+
942
+ def get_output_embeddings(self):
943
+ return self.lm_head
944
+
945
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
946
+ self.lm_head = new_embeddings
947
+
948
+ # EM: Clear memory cache
949
+ def clear_memory(self):
950
+ """Clear memory cache."""
951
+ self.memory_ids = None
952
+ self.memories = None
953
+
954
+ def prepare_inputs_for_generation(
955
+ self,
956
+ input_ids: torch.LongTensor,
957
+ past_key_values: Optional[torch.Tensor] = None,
958
+ attention_mask: Optional[torch.Tensor] = None,
959
+ inputs_embeds: Optional[torch.Tensor] = None,
960
+ use_cache: Optional[bool] = None,
961
+ **kwargs,
962
+ ) -> dict:
963
+ # only last token for input_ids if past is not None
964
+ if past_key_values:
965
+ input_ids = input_ids[:, -1].unsqueeze(-1)
966
+
967
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
968
+ if inputs_embeds is not None and past_key_values is None:
969
+ model_inputs = {"inputs_embeds": inputs_embeds}
970
+ else:
971
+ model_inputs = {"input_ids": input_ids}
972
+
973
+ model_inputs.update(
974
+ {
975
+ "past_key_values": past_key_values, # NITS should it be layer_past?
976
+ "use_cache": use_cache,
977
+ "attention_mask": attention_mask,
978
+ "use_external_mind": kwargs.get("use_external_mind"), # EM: Add config here
979
+ "topk": kwargs.get("topk"),
980
+ }
981
+ )
982
+ return model_inputs
983
+
984
+ @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
985
+ @add_code_sample_docstrings(
986
+ checkpoint=_CHECKPOINT_FOR_DOC,
987
+ output_type=CausalLMOutputWithCrossAttentions,
988
+ config_class=_CONFIG_FOR_DOC,
989
+ )
990
+ def forward(
991
+ self,
992
+ input_ids: Optional[torch.LongTensor] = None,
993
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
994
+ attention_mask: Optional[torch.Tensor] = None,
995
+ inputs_embeds: Optional[torch.Tensor] = None,
996
+ labels: Optional[torch.Tensor] = None,
997
+ use_cache: Optional[bool] = None,
998
+ output_attentions: Optional[bool] = None,
999
+ output_retrieved_memory_idx: Optional[bool] = None,
1000
+ output_hidden_states: Optional[bool] = None,
1001
+ return_dict: Optional[bool] = None,
1002
+ use_external_mind: Optional[bool] = None,
1003
+ topk: int = None,
1004
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1005
+ r"""
1006
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1007
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1008
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1009
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1010
+ """
1011
+ return_dict = (
1012
+ return_dict if return_dict is not None else self.config.use_return_dict
1013
+ )
1014
+
1015
+ # EM: Generate key value cache once on first call
1016
+ if (
1017
+ self.memory_ids is not None and self.memories is None
1018
+ ):
1019
+ self.memories = self.generate_cache(
1020
+ self.memory_ids, cache_type=self.memory_type
1021
+ )
1022
+ # EM: Remove special tokens from memory cache
1023
+ if self.remove_special_ids:
1024
+ idx_to_remove = [
1025
+ token_idx
1026
+ for token_idx, token in enumerate(self.memory_ids[0])
1027
+ if token in self.tokenizer_all_special_ids
1028
+ ]
1029
+ if self.memory_type == "manual":
1030
+ mask = torch.ones(self.memories[0][0].size(), dtype=torch.bool)
1031
+ mask[:, :, idx_to_remove, :] = False
1032
+
1033
+ new_size = (
1034
+ self.memories[0][0].size(0),
1035
+ self.memories[0][0].size(1),
1036
+ -1,
1037
+ self.memories[0][0].size(3),
1038
+ )
1039
+ self.memories = [
1040
+ (ks[mask].view(new_size), vs[mask].view(new_size))
1041
+ for ks, vs in self.memories
1042
+ ]
1043
+ else:
1044
+ kn_index, kv_index = self.memories
1045
+ all_idx_to_remove = [
1046
+ [
1047
+ i
1048
+ for i in range(0, kn_index.ntotal)
1049
+ if (
1050
+ i
1051
+ % (
1052
+ kn_index.ntotal
1053
+ / (
1054
+ self.config.num_attention_heads
1055
+ * self.config.num_hidden_layers
1056
+ )
1057
+ )
1058
+ )
1059
+ == j
1060
+ ]
1061
+ for j in idx_to_remove
1062
+ ]
1063
+ kn_index.remove_ids(
1064
+ np.array(all_idx_to_remove).flatten().astype("int64")
1065
+ )
1066
+ kv_index.remove_ids(
1067
+ np.array(all_idx_to_remove).flatten().astype("int64")
1068
+ )
1069
+
1070
+ use_external_mind = (
1071
+ use_external_mind
1072
+ if use_external_mind is not None
1073
+ else self.use_external_mind
1074
+ )
1075
+ topk = topk if topk is not None else None
1076
+
1077
+ long_range_past_key_values = None
1078
+ faiss_indexes = None
1079
+ if hasattr(self, "memories") and isinstance(self.memories, list):
1080
+ long_range_past_key_values = self.memories
1081
+ elif hasattr(self, "memories"):
1082
+ faiss_indexes = self.memories
1083
+
1084
+ transformer_outputs = self.transformer(
1085
+ input_ids,
1086
+ past_key_values=past_key_values,
1087
+ attention_mask=attention_mask,
1088
+ inputs_embeds=inputs_embeds,
1089
+ use_cache=use_cache,
1090
+ output_attentions=output_attentions,
1091
+ output_retrieved_memory_idx=output_retrieved_memory_idx,
1092
+ output_hidden_states=output_hidden_states,
1093
+ return_dict=return_dict,
1094
+ long_range_past_key_values=long_range_past_key_values,
1095
+ faiss_indexes=faiss_indexes,
1096
+ use_external_mind=use_external_mind,
1097
+ topk=topk,
1098
+ )
1099
+ hidden_states = transformer_outputs[0]
1100
+
1101
+ lm_logits = self.lm_head(hidden_states)
1102
+
1103
+ loss = None
1104
+ if labels is not None:
1105
+ # move labels to correct device to enable model parallelism
1106
+ labels = labels.to(lm_logits.device)
1107
+ # Shift so that tokens < n predict n
1108
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1109
+ shift_labels = labels[..., 1:].contiguous()
1110
+ batch_size, seq_length, vocab_size = shift_logits.shape
1111
+ # Flatten the tokens
1112
+ loss_fct = CrossEntropyLoss()
1113
+ loss = loss_fct(
1114
+ shift_logits.view(batch_size * seq_length, vocab_size),
1115
+ shift_labels.view(batch_size * seq_length),
1116
+ )
1117
+
1118
+ if not return_dict:
1119
+ output = (lm_logits,) + transformer_outputs[1:]
1120
+ return ((loss,) + output) if loss is not None else output
1121
+
1122
+ return CausalLMOutputWithCrossAttentions(
1123
+ loss=loss,
1124
+ logits=lm_logits,
1125
+ past_key_values=transformer_outputs.past_key_values,
1126
+ hidden_states=transformer_outputs.hidden_states,
1127
+ attentions=transformer_outputs.attentions,
1128
+ )
1129
+
1130
+ def _reorder_cache(
1131
+ self,
1132
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
1133
+ beam_idx: torch.LongTensor,
1134
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1135
+ """
1136
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1137
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1138
+ beam_idx at every generation step.
1139
+
1140
+ Output shares the same memory storage as `past`.
1141
+ """
1142
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
1143
+ device_to_beam_idx = {
1144
+ past_state.device: beam_idx.to(past_state.device)
1145
+ for layer_past in past
1146
+ for past_state in layer_past
1147
+ }
1148
+ reordered_past = tuple(
1149
+ (
1150
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
1151
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
1152
+ )
1153
+ for layer_past in past
1154
+ )
1155
+ return reordered_past
1156
+
1157
+ # EM: Add method to generate key-value cache
1158
+ def generate_cache(
1159
+ self,
1160
+ input_ids: torch.LongTensor,
1161
+ stride: int = 512,
1162
+ max_len: int = 3072,
1163
+ cache_type: str = "manual",
1164
+ ):
1165
+ """Generate cache for long range attention."""
1166
+ if cache_type not in ["manual", "faiss"]:
1167
+ raise NotImplementedError(f"Cache type {cache_type} not implemented.")
1168
+
1169
+ prev_end_loc = 0
1170
+ long_range_past_key_values = None
1171
+ faiss_indexes = None
1172
+ for b_idx in range(
1173
+ 0, input_ids.size(-1), stride
1174
+ ): # generate kv-pairs using stride
1175
+ end_loc = min(b_idx + max_len, input_ids.size(-1))
1176
+ trg_len = end_loc - prev_end_loc
1177
+ subseq = input_ids[:, b_idx:end_loc].to(self.device)
1178
+ with torch.no_grad():
1179
+ outputs = self.transformer(
1180
+ subseq, use_cache=True, use_external_mind=False
1181
+ )
1182
+ to_cache = [
1183
+ (kv[0][:, :, -trg_len:], kv[1][:, :, -trg_len:])
1184
+ for kv in outputs.past_key_values
1185
+ ]
1186
+ long_range_past_key_values, faiss_indexes = self.cache(
1187
+ to_cache,
1188
+ cache_type,
1189
+ long_range_past_key_values=long_range_past_key_values,
1190
+ faiss_indexes=faiss_indexes,
1191
+ )
1192
+
1193
+ prev_end_loc = end_loc
1194
+ if end_loc == input_ids.size(-1):
1195
+ break
1196
+ if long_range_past_key_values is not None:
1197
+ return long_range_past_key_values
1198
+ else:
1199
+ return faiss_indexes
1200
+
1201
+ # EM: Add method to cache key value pairs
1202
+ def cache(
1203
+ self,
1204
+ to_cache: list,
1205
+ cache_type: str = "manual",
1206
+ long_range_past_key_values: list = None,
1207
+ faiss_indexes: faiss.IndexFlatIP = None,
1208
+ max_length_cache=100000,
1209
+ verbose=False,
1210
+ ):
1211
+ """Cache long range attention."""
1212
+ if (long_range_past_key_values is not None) & (faiss_indexes is not None):
1213
+ raise NotImplementedError(
1214
+ "Using faiss and passing key value pairs manually are mutually exclusive right now."
1215
+ )
1216
+
1217
+ # To avoid spinning up a new index for each layer, we add one-hot encodings to the keys so that queries match with the appropriate layer, head
1218
+ if cache_type == "faiss": # add one-hot encoding to match layer, head indices
1219
+ one_hot_encodings = (
1220
+ F.one_hot(torch.arange(0, self.config.n_heads * self.config.n_layers))
1221
+ * 10
1222
+ )
1223
+ # New indices, one to store normalized keys with one-hot encodings, another to retrieve kv pairs without normalization
1224
+ if faiss_indexes is None:
1225
+ faiss_indexes = (
1226
+ faiss.IndexFlatIP(
1227
+ to_cache[0][0].size(-1) + one_hot_encodings.size(-1)
1228
+ ),
1229
+ faiss.IndexFlatIP(to_cache[0][0].size(-1) * 2),
1230
+ )
1231
+ kn_index, kv_index = faiss_indexes
1232
+ for l_idx, (k, v) in enumerate(to_cache):
1233
+ k_n = (k / vector_norm(k, ord=2, dim=-1, keepdim=True)).to("cpu") #Normalize keys for cosine sim
1234
+
1235
+ # Indices are 2 dimensional, so flatten
1236
+ # Add normalized keys with one-hot encodings
1237
+ k_n = torch.concat(
1238
+ [
1239
+ rearrange(k_n, "b h s d -> b (h s) d", h=self.config.n_heads),
1240
+ one_hot_encodings[
1241
+ self.config.n_heads
1242
+ * l_idx : self.config.n_heads
1243
+ * (l_idx + 1)
1244
+ ]
1245
+ .unsqueeze(0)
1246
+ .repeat_interleave(repeats=k.size(-2), dim=-2),
1247
+ ],
1248
+ dim=-1,
1249
+ )
1250
+ kn_index.add(k_n.squeeze().numpy())
1251
+
1252
+ # Add unnormalized keys and values
1253
+ k = rearrange(k, "b h s d -> b (h s) d", h=self.config.n_heads)
1254
+ v = rearrange(v, "b h s d -> b (h s) d", h=self.config.n_heads)
1255
+ kv_index.add(
1256
+ torch.concat([k.squeeze(), v.squeeze()], dim=1).to("cpu").numpy()
1257
+ )
1258
+ else:
1259
+ # Simply use list to store key value pairs
1260
+ if long_range_past_key_values is None:
1261
+ long_range_past_key_values = [
1262
+ (k.to(self.memory_device), v.to(self.memory_device))
1263
+ for k, v in to_cache
1264
+ ]
1265
+ else:
1266
+ long_range_past_key_values = [
1267
+ (
1268
+ torch.concat(
1269
+ [kv[0], to_cache[ind][0].to(self.memory_device)], dim=2
1270
+ ),
1271
+ torch.concat(
1272
+ [kv[1], to_cache[ind][1].to(self.memory_device)], dim=2
1273
+ ),
1274
+ )
1275
+ for ind, kv in enumerate(long_range_past_key_values)
1276
+ ]
1277
+ if (
1278
+ long_range_past_key_values is not None
1279
+ ): # set a limit on manual memory length
1280
+ if long_range_past_key_values[0][0].size(-2) > max_length_cache:
1281
+ long_range_past_key_values = [
1282
+ (
1283
+ kv[0][:, :, -max_length_cache:],
1284
+ kv[1][:, :, -max_length_cache:],
1285
+ )
1286
+ for kv in long_range_past_key_values
1287
+ ]
1288
+ if verbose:
1289
+ if cache_type == "faiss":
1290
+ print(f"{kn_index.ntotal} keys in faiss index")
1291
+ else:
1292
+ print(f"{long_range_past_key_values[0][0].size(-2)} cached kvs")
1293
+
1294
+ return (
1295
+ long_range_past_key_values,
1296
+ (kn_index, kv_index) if cache_type == "faiss" else None,
1297
+ )