Refactor landmark attention patch
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
@@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id):
     ret.extend(x[prev_idx:])
     # drop attention_mask
     return {"input_ids": ret}
+
+
+def patch_llama_with_landmark_attn():
+    import transformers
+
+    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+    transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
+    transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
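For reference, `patch_llama_with_landmark_attn()` rebinds the Llama classes inside `transformers.models.llama.modeling_llama` to the landmark-attention variants defined earlier in this module. A minimal sketch of calling the patch and checking that the rebinding took effect (illustrative only, not part of this commit):

# Sketch: apply the landmark-attention monkeypatch, then confirm that
# transformers' modeling_llama module now exposes the patched classes.
import transformers

from axolotl.monkeypatch import llama_landmark_attn

llama_landmark_attn.patch_llama_with_landmark_attn()

# The module-level names now resolve to the landmark-attention variants.
assert (
    transformers.models.llama.modeling_llama.LlamaAttention
    is llama_landmark_attn.LlamaAttention
)
assert (
    transformers.models.llama.modeling_llama.LlamaForCausalLM
    is llama_landmark_attn.LlamaForCausalLM
)

Code that resolved `LlamaForCausalLM` from `transformers` before the patch ran would still hold the unpatched class, which is presumably why `models.py` below defers that import until after the patch is applied.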
src/axolotl/utils/models.py
CHANGED
@@ -19,15 +19,6 @@ from transformers import ( # noqa: F401
     LlamaConfig,
 )
 
-try:
-    from transformers import ( # pylint: disable=unused-import # noqa: F401
-        LlamaForCausalLM,
-    )
-except ImportError:
-    logging.warning(
-        "This version of transformers does not support Llama. Consider upgrading."
-    )
-
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
@@ -118,14 +109,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (
+        from axolotl.monkeypatch.llama_landmark_attn import (
             MEM_TOKEN,
-
+            patch_llama_with_landmark_attn,
         )
 
         logging.info("patching with landmark attention")
+        patch_llama_with_landmark_attn()
 
-        #
+        # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -211,6 +203,13 @@ def load_model(
             )
             load_in_8bit = False
         elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
+            try:
+                from transformers import LlamaForCausalLM
+            except ImportError:
+                logging.warning(
+                    "This version of transformers does not support Llama. Consider upgrading."
+                )
+
             config = LlamaConfig.from_pretrained(base_model_config)
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
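Taken together, the `load_model` changes set up an ordering: patch the Llama classes, register the landmark memory token with the tokenizer, and only then import and instantiate `LlamaForCausalLM`. A rough standalone sketch of that ordering, with a placeholder model name and a typical embedding resize that does not come from this diff:

# Sketch of the ordering load_model now follows for landmark attention.
# "huggyllama/llama-7b" is a placeholder, and resize_token_embeddings is a
# typical follow-up when adding a special token; neither comes from this diff.
from transformers import AutoTokenizer

from axolotl.monkeypatch.llama_landmark_attn import (
    MEM_TOKEN,
    patch_llama_with_landmark_attn,
)

patch_llama_with_landmark_attn()  # rebind the Llama classes first

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

# Imported only after the patch, mirroring the deferred import in load_model.
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("huggyllama/llama-7b")
model.resize_token_embeddings(len(tokenizer))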