winglian committed
Commit 1687be6
Parent: 41ecb45

don't use mask expansion for inference (#392)

examples/llama-2/lora.yml CHANGED
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 
 load_in_8bit: true
 load_in_4bit: false
examples/llama-2/qlora.yml CHANGED
@@ -2,6 +2,7 @@ base_model: meta-llama/Llama-2-7b-hf
 base_model_config: meta-llama/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
 
 load_in_8bit: false
 load_in_4bit: true
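
Both example configs now set is_llama_derived_model explicitly, which is the flag the model loader checks before applying the LLaMA-specific monkeypatches. A minimal sketch of reading the flag back (illustrative only; it assumes PyYAML and the repo-relative path shown above):

import yaml

# Load one of the updated example configs and check the new key.
with open("examples/llama-2/qlora.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["is_llama_derived_model"])  # True after this commit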
src/axolotl/utils/models.py CHANGED
@@ -138,8 +138,10 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()
 
-    if cfg.is_llama_derived_model and (
-        cfg.max_packed_sequence_len or cfg.sample_packing
+    if (
+        cfg.is_llama_derived_model
+        and (cfg.max_packed_sequence_len or cfg.sample_packing)
+        and not cfg.inference
     ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
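
The reworked condition in load_model() skips the expand-mask monkeypatch at inference time: the patch is only relevant when sample packing (or max_packed_sequence_len) is configured for training. A minimal sketch of how the new gate behaves; the helper name and the SimpleNamespace configs are illustrative, not part of axolotl:

from types import SimpleNamespace

def should_hijack_expand_mask(cfg) -> bool:
    # Mirrors the condition added in this commit: LLaMA-derived model,
    # packing enabled, and not running inference.
    return bool(
        cfg.is_llama_derived_model
        and (cfg.max_packed_sequence_len or cfg.sample_packing)
        and not cfg.inference
    )

# Packed training run -> the monkeypatch is applied.
train_cfg = SimpleNamespace(
    is_llama_derived_model=True,
    max_packed_sequence_len=None,
    sample_packing=True,
    inference=False,
)
assert should_hijack_expand_mask(train_cfg)

# Same config at inference time -> the monkeypatch is skipped.
infer_cfg = SimpleNamespace(**{**vars(train_cfg), "inference": True})
assert not should_hijack_expand_mask(infer_cfg)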