Are LoRA weights loaded?
#4
by syncdoth - opened
- I am not sure if
model = LlamaForCausalLM.from_pretrained("chainyo/alpaca-lora-7b")
is enough to load the LoRA weights in adapter_model.bin too. I loaded the model using the example, and looking at the state_dict:
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("chainyo/alpaca-lora-7b")
state_dict = model.state_dict()
print([k for k in state_dict.keys() if 'lora' in k])
>>> []
The state_dict in adapter_model.bin, on the other hand, contains:
import torch

adapter_dict = torch.load('adapter_model.bin')
print(adapter_dict.keys())
>>> dict_keys(['base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight', 'base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight', ...
If this is the case, then the loading script should be something like:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel, PeftConfig

peft_model_id = "chainyo/alpaca-lora-7b"
config = PeftConfig.from_pretrained(peft_model_id)

model = LlamaForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = LlamaTokenizer.from_pretrained(peft_model_id)
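With the adapters attached through PeftModel, the earlier check should now find LoRA parameters (a small sanity check, assuming the snippet above ran):

lora_keys = [k for k in model.state_dict().keys() if 'lora' in k]
print(len(lora_keys) > 0)  # should print True once the adapters are injected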
Also, the base_model_name_or_path in adapter_config.json might need to be corrected so that it doesn't download the original LLaMA weights again.
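For reference, here is a minimal sketch of patching that field; the local path and the base-model repo id below are assumptions, so adjust them to your setup:

import json

# Assumed local path to the downloaded adapter repo -- adjust as needed.
adapter_config_path = "alpaca-lora-7b/adapter_config.json"

with open(adapter_config_path) as f:
    adapter_config = json.load(f)

# Check which base checkpoint the adapter currently points to.
print(adapter_config["base_model_name_or_path"])

# Point it at a base LLaMA checkpoint you already have locally so the
# original weights are not downloaded a second time (assumed repo id).
adapter_config["base_model_name_or_path"] = "decapoda-research/llama-7b-hf"

with open(adapter_config_path, "w") as f:
    json.dump(adapter_config, f, indent=2)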
- On second thought, are the LoRA weights even required for the model to perform as expected? Maybe the LLaMA weights themselves are already tuned... (but then why the name LoRA?)
Thanks!
On deeper investigation, I found that only self_attn.q_proj.weight, self_attn.rotary_emb.inv_freq, and self_attn.v_proj.weight diverge from the original LLaMA weights. Since LoRA is applied to q_proj and v_proj, I suppose these weights were reconstructed from the LoRA A and B matrices (I found PEFT's merge_and_unload function).
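For anyone curious, this is roughly the relationship that merging applies to a LoRA-adapted linear layer; a minimal sketch, where the r and lora_alpha values are placeholders that should be read from adapter_config.json:

import torch

adapter_dict = torch.load("adapter_model.bin", map_location="cpu")

# Low-rank factors for layer 0's q_proj: A is (r, in_features), B is (out_features, r).
lora_A = adapter_dict["base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"].float()
lora_B = adapter_dict["base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"].float()

r, lora_alpha = 16, 16  # placeholders -- use the real values from adapter_config.json
scaling = lora_alpha / r

# The low-rank update that gets folded into the frozen q_proj weight,
# so merged_q_proj.weight == original_q_proj.weight + delta.
delta = (lora_B @ lora_A) * scaling
print(delta.shape)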
syncdoth changed discussion status to closed
I chose to share both the adapters and the LoRA weights fully merged into the base model.
You can choose to load only the adapters or the full model.
- Full model:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("chainyo/alpaca-lora-7b")
model = LlamaForCausalLM.from_pretrained(
    "chainyo/alpaca-lora-7b",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
- Using the adapters:
import torch
from transformers import LlamaForCausalLM
from peft import PeftModel

model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    "chainyo/alpaca-lora-7b",
    torch_dtype=torch.float16,
)
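Either way, once the model (and a tokenizer) is loaded, generation is the usual transformers call. A minimal sketch; the prompt below uses the common Alpaca instruction template, which may not match the model card exactly:

import torch

prompt = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\nExplain what LoRA is.\n\n### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))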