attempt xformers hijack attention
src/axolotl/utils/models.py CHANGED
```diff
@@ -43,6 +43,10 @@ def load_model(
 
             logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
+    elif is_llama_derived_model and cfg.xformers_attention:
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+        logging.info("patching with xformers attention")
+        hijack_llama_attention()
 
     torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
     try:
```
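For context, the imported hijack_llama_attention() relies on the usual "attention hijack" pattern: reassign an attention class's forward so the softmax(QK^T)V step runs through xformers' memory-efficient kernel instead of the naive implementation. The sketch below illustrates that pattern on a hypothetical ToyAttention module rather than transformers' LlamaAttention, so it makes no claims about alpaca_lora_4bit's actual internals; it assumes xformers is installed and a CUDA device is available.

```python
# Minimal sketch of the attention-hijack pattern (illustrative only; this is not
# the alpaca_lora_4bit implementation). Assumes xformers + a CUDA device.
import torch
import torch.nn as nn
import xformers.ops as xops


class ToyAttention(nn.Module):
    """Tiny multi-head self-attention used only to demonstrate the monkeypatch."""

    def __init__(self, hidden_size: int = 256, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.out = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Naive reference attention: materializes the full [L, L] score matrix.
        bsz, seq_len, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (
            t.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        scores = q @ k.transpose(-2, -1) / self.head_dim**0.5
        attn = scores.softmax(dim=-1) @ v
        return self.out(attn.transpose(1, 2).reshape(bsz, seq_len, -1))


def xformers_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Replacement forward: same projections, but the softmax(QK^T)V step is
    # delegated to xformers' fused kernel (expects [B, L, H, D] layout).
    bsz, seq_len, _ = x.shape
    q, k, v = self.qkv(x).chunk(3, dim=-1)
    q, k, v = (t.reshape(bsz, seq_len, self.num_heads, self.head_dim) for t in (q, k, v))
    attn = xops.memory_efficient_attention(q, k, v)
    return self.out(attn.reshape(bsz, seq_len, -1))


def hijack_toy_attention() -> None:
    # The "hijack": reassign the class attribute so every instance, existing or
    # future, picks up the xformers-backed forward. The real patch does the
    # analogous swap on transformers' LlamaAttention.
    ToyAttention.forward = xformers_forward


if __name__ == "__main__":
    hijack_toy_attention()
    model = ToyAttention().half().cuda()
    out = model(torch.randn(2, 128, 256, dtype=torch.float16, device="cuda"))
    print(out.shape)  # torch.Size([2, 128, 256])
```

Patching at the class level means every instance picks up the new forward without modifying the transformers package itself, which is presumably why a single hijack_llama_attention() call inside load_model() is enough.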