don't use mask expansion for inference (#392)
- examples/llama-2/lora.yml +1 -0
- examples/llama-2/qlora.yml +1 -0
- src/axolotl/utils/models.py +4 -2
examples/llama-2/lora.yml CHANGED

```diff
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 
 load_in_8bit: true
 load_in_4bit: false
```
examples/llama-2/qlora.yml CHANGED

```diff
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 
 load_in_8bit: false
 load_in_4bit: true
```
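Both example configs now set `is_llama_derived_model: true` explicitly, since that flag is one of the conditions gating the Llama expand-mask monkeypatch in `load_model` (see the diff below). As a rough sketch of how that gate behaves after this change, with a `SimpleNamespace` standing in for the parsed config; the helper name `should_hijack_expand_mask` is illustrative, not part of axolotl:

```python
from types import SimpleNamespace

def should_hijack_expand_mask(cfg) -> bool:
    # Mirrors the condition in src/axolotl/utils/models.py after this change:
    # Llama-derived model, packing enabled, and not running inference.
    return bool(
        cfg.is_llama_derived_model
        and (cfg.max_packed_sequence_len or cfg.sample_packing)
        and not cfg.inference
    )

train_cfg = SimpleNamespace(
    is_llama_derived_model=True,  # set by the yml examples above
    max_packed_sequence_len=None,
    sample_packing=True,
    inference=False,
)
infer_cfg = SimpleNamespace(**{**vars(train_cfg), "inference": True})

print(should_hijack_expand_mask(train_cfg))  # True  -> patch applied for training
print(should_hijack_expand_mask(infer_cfg))  # False -> patch skipped at inference
```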
src/axolotl/utils/models.py CHANGED

```diff
@@ -138,8 +138,10 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()
 
-    if cfg.is_llama_derived_model and (
-        cfg.max_packed_sequence_len or cfg.sample_packing
+    if (
+        cfg.is_llama_derived_model
+        and (cfg.max_packed_sequence_len or cfg.sample_packing)
+        and not cfg.inference
     ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
```
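For context, `hijack_expand_mask` patches Llama's attention-mask expansion so that samples packed into one sequence cannot attend across sample boundaries; that is only needed when sample packing is in play during training, and the new `and not cfg.inference` guard keeps the patch off for inference, as the PR title says. The sketch below illustrates the block-diagonal idea under the assumption that packed samples are tagged 1, 2, ... in the attention mask with 0 for padding; it is not axolotl's actual implementation:

```python
import torch

def expand_packed_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a (batch, seq_len) packed-sample-id mask into an additive
    (batch, 1, seq_len, seq_len) attention mask that is block-diagonal,
    so tokens only attend within their own packed sample."""
    # Positions may attend to each other only if their sample ids match
    # and neither position is padding (id 0).
    same_sample = mask.unsqueeze(1) == mask.unsqueeze(2)
    not_padding = (mask != 0).unsqueeze(1) & (mask != 0).unsqueeze(2)
    allowed = same_sample & not_padding
    # Additive mask: 0 where attention is allowed, a large negative where not.
    zero = torch.zeros((), dtype=dtype)
    neg_inf = torch.full((), torch.finfo(dtype).min, dtype=dtype)
    return torch.where(allowed, zero, neg_inf).unsqueeze(1)

# Two samples of lengths 2 and 3 packed into one row of length 6, plus one pad.
packed = torch.tensor([[1, 1, 2, 2, 2, 0]])
expanded = expand_packed_mask(packed, torch.float32)
print(expanded.shape)  # torch.Size([1, 1, 6, 6])
# (Causal masking is applied separately by the model on top of this.)
```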