from typing import Union

import torch
from peft import PeftModel
from peft.tuners.lora import LoraModel
from transformers import LlamaForCausalLM as ModelCls
from transformers import LlamaTokenizerFast as TkCls

PeftCls = Union[PeftModel, LoraModel]
orig_model = "TheBloke/Llama-2-7B-Chat-fp16"
lora_model = "models/Llama-7B-TwAddr-LoRA"
output_dir = "models/Llama-7B-TwAddr-Merged"
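
# Load the base model in fp16 with reduced peak memory usage.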
model = ModelCls.from_pretrained(
    orig_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Reset the sampling parameters so the generation config passes validation on save.
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0
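
# Attach the trained LoRA adapter on top of the base model.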
model: PeftCls = PeftModel.from_pretrained(
    model,
    lora_model,
    torch_dtype=torch.float16,
)
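
# Fold the LoRA deltas into the base weights and drop the PEFT wrapper.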
model = model.merge_and_unload()
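
# Write the merged weights to disk in safetensors format.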
model.save_pretrained(
    output_dir,
    safe_serialization=True,
)

# The tokenizer also has to be saved as a separate copy.
tk: TkCls = TkCls.from_pretrained(orig_model)
tk.save_pretrained(output_dir)
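
# Minimal sanity check (a sketch; `merged` and `merged_tk` are illustrative
# names, not part of the script above): after the merge, the output directory
# should load as a plain LlamaForCausalLM with no PEFT dependency.
merged = ModelCls.from_pretrained(
    output_dir,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
merged_tk = TkCls.from_pretrained(output_dir)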