phoebeklett
commited on
Commit
•
6d969f3
1
Parent(s):
03f1d2a
Upload 2 files
Browse files- configuration.py +5 -4
- modeling.py +4 -3
configuration.py
CHANGED
@@ -101,6 +101,7 @@ class ExtendedMptAttentionConfig(PretrainedConfig):
|
|
101 |
sim_threshold=0.25,
|
102 |
tokenizer_all_special_ids=[0, 50278],
|
103 |
remove_special_ids=False,
|
|
|
104 |
**kwargs,
|
105 |
):
|
106 |
super().__init__(**kwargs)
|
@@ -121,6 +122,7 @@ class ExtendedMptAttentionConfig(PretrainedConfig):
|
|
121 |
self.sim_threshold = sim_threshold
|
122 |
self.tokenizer_all_special_ids = tokenizer_all_special_ids
|
123 |
self.remove_special_ids = remove_special_ids
|
|
|
124 |
|
125 |
if attn_type not in ["multihead_attention", "multiquery_attention"]:
|
126 |
raise ValueError(
|
@@ -245,7 +247,6 @@ class ExtendedMptConfig(PretrainedConfig):
|
|
245 |
n_layers: int = 32,
|
246 |
expansion_ratio: int = 4,
|
247 |
max_seq_len_inference: int = 2048,
|
248 |
-
max_seq_len_train: int = 2048,
|
249 |
vocab_size: int = 50432,
|
250 |
resid_pdrop: float = 0.0,
|
251 |
layer_norm_epsilon: float = 1e-5,
|
@@ -261,11 +262,12 @@ class ExtendedMptConfig(PretrainedConfig):
|
|
261 |
use_cache: bool = False,
|
262 |
initializer_range=0.02,
|
263 |
use_external_mind: bool = True,
|
264 |
-
use_external_mind_by_layer: list[bool] = [True for _ in range(32)],
|
265 |
**kwargs,
|
266 |
):
|
267 |
if attn_config is None:
|
268 |
-
self.attn_config = ExtendedMptAttentionConfig(
|
|
|
|
|
269 |
elif not isinstance(attn_config, ExtendedMptAttentionConfig):
|
270 |
self.attn_config = ExtendedMptAttentionConfig(**attn_config)
|
271 |
else:
|
@@ -275,7 +277,6 @@ class ExtendedMptConfig(PretrainedConfig):
|
|
275 |
self.n_layers = n_layers
|
276 |
self.expansion_ratio = expansion_ratio
|
277 |
self.max_seq_len = max_seq_len_inference
|
278 |
-
self.max_seq_len_train = max_seq_len_train
|
279 |
self.vocab_size = vocab_size
|
280 |
self.resid_pdrop = resid_pdrop
|
281 |
self.emb_pdrop = emb_pdrop
|
|
|
101 |
sim_threshold=0.25,
|
102 |
tokenizer_all_special_ids=[0, 50278],
|
103 |
remove_special_ids=False,
|
104 |
+
use_external_mind_by_layer: list[bool] = [True for _ in range(32)],
|
105 |
**kwargs,
|
106 |
):
|
107 |
super().__init__(**kwargs)
|
|
|
122 |
self.sim_threshold = sim_threshold
|
123 |
self.tokenizer_all_special_ids = tokenizer_all_special_ids
|
124 |
self.remove_special_ids = remove_special_ids
|
125 |
+
self.use_external_mind_by_layer = use_external_mind_by_layer
|
126 |
|
127 |
if attn_type not in ["multihead_attention", "multiquery_attention"]:
|
128 |
raise ValueError(
|
|
|
247 |
n_layers: int = 32,
|
248 |
expansion_ratio: int = 4,
|
249 |
max_seq_len_inference: int = 2048,
|
|
|
250 |
vocab_size: int = 50432,
|
251 |
resid_pdrop: float = 0.0,
|
252 |
layer_norm_epsilon: float = 1e-5,
|
|
|
262 |
use_cache: bool = False,
|
263 |
initializer_range=0.02,
|
264 |
use_external_mind: bool = True,
|
|
|
265 |
**kwargs,
|
266 |
):
|
267 |
if attn_config is None:
|
268 |
+
self.attn_config = ExtendedMptAttentionConfig(
|
269 |
+
use_external_mind_by_layer=[True for _ in range(n_layers)]
|
270 |
+
)
|
271 |
elif not isinstance(attn_config, ExtendedMptAttentionConfig):
|
272 |
self.attn_config = ExtendedMptAttentionConfig(**attn_config)
|
273 |
else:
|
|
|
277 |
self.n_layers = n_layers
|
278 |
self.expansion_ratio = expansion_ratio
|
279 |
self.max_seq_len = max_seq_len_inference
|
|
|
280 |
self.vocab_size = vocab_size
|
281 |
self.resid_pdrop = resid_pdrop
|
282 |
self.emb_pdrop = emb_pdrop
|
modeling.py
CHANGED
@@ -42,7 +42,7 @@ from transformers.modeling_outputs import (
|
|
42 |
from transformers.modeling_utils import PreTrainedModel
|
43 |
from transformers.utils import logging
|
44 |
|
45 |
-
from .configuration import ExtendedMptConfig
|
46 |
|
47 |
logger = logging.get_logger(__name__)
|
48 |
|
@@ -920,7 +920,7 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
|
|
920 |
|
921 |
_tied_weights_keys = ["lm_head.weight"]
|
922 |
|
923 |
-
def __init__(self, config: ExtendedMptConfig, external_memories=None):
|
924 |
super().__init__(config)
|
925 |
self.transformer: ExtendedMptModel = ExtendedMptModel(config)
|
926 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
@@ -1016,8 +1016,9 @@ class ExtendedMptForCausalLM(MptPreTrainedModel):
|
|
1016 |
if (
|
1017 |
self.memory_ids is not None and self.memories is None
|
1018 |
):
|
|
|
1019 |
self.memories = self.generate_cache(
|
1020 |
-
self.memory_ids, cache_type=self.memory_type
|
1021 |
)
|
1022 |
# EM: Remove special tokens from memory cache
|
1023 |
if self.remove_special_ids:
|
|
|
42 |
from transformers.modeling_utils import PreTrainedModel
|
43 |
from transformers.utils import logging
|
44 |
|
45 |
+
from emts_clean.src.mpt.configuration import ExtendedMptConfig
|
46 |
|
47 |
logger = logging.get_logger(__name__)
|
48 |
|
|
|
920 |
|
921 |
_tied_weights_keys = ["lm_head.weight"]
|
922 |
|
923 |
+
def __init__(self, config: ExtendedMptConfig, external_memories:list=None):
|
924 |
super().__init__(config)
|
925 |
self.transformer: ExtendedMptModel = ExtendedMptModel(config)
|
926 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
1016 |
if (
|
1017 |
self.memory_ids is not None and self.memories is None
|
1018 |
):
|
1019 |
+
self.memory_ids = torch.tensor([self.memory_ids], device=self.device) if type(self.memory_ids)==list else self.memory_ids
|
1020 |
self.memories = self.generate_cache(
|
1021 |
+
self.memory_ids, cache_type=self.memory_type,
|
1022 |
)
|
1023 |
# EM: Remove special tokens from memory cache
|
1024 |
if self.remove_special_ids:
|