"""Merge a LoRA adapter into its base model and save the result."""

import argparse
import os

import torch
from peft import PeftModel  # Requires the 'peft' library
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_logger  # Must be provided by your environment

logger = get_logger("merge", "info")


def smart_tokenizer_and_embedding_resize(tokenizer, model, custom_tokens_path=None):
    """Resize tokenizer and embedding to accommodate new tokens."""
    # Add a pad token (Llama-2 ships without one). The eos/bos/unk values
    # below match the Llama-2 tokenizer defaults.
    special_tokens_dict = {
        "pad_token": "[PAD]",
        "eos_token": "</s>",
        "bos_token": "<s>",
        "unk_token": "<unk>",
    }

    # Load custom tokens if specified, one token per line.
    custom_tokens = []
    if custom_tokens_path is not None:
        with open(custom_tokens_path, "r") as file:
            custom_tokens = [line.strip() for line in file if line.strip()]

    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    if custom_tokens:
        num_added_toks += tokenizer.add_tokens(custom_tokens, special_tokens=True)

    # Grow the model's embedding matrix to match the enlarged vocabulary.
    model.resize_token_embeddings(len(tokenizer))
    logger.info(f"Resized tokenizer and model embeddings. Added {num_added_toks} tokens.")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-bm", "--base_model", type=str,
                        default="meta-llama/Llama-2-7b-chat-hf",
                        help="Base model name or path")
    parser.add_argument("-lm", "--lora_model", type=str, required=True,
                        help="Path to the LoRA model directory")
    parser.add_argument("-o", "--output", type=str, required=True,
                        help="Output directory for the merged model")
    parser.add_argument("--custom_tokens", type=str, default=None,
                        help="Path to a file containing custom tokens")
    args = parser.parse_args()

    if not os.path.exists(args.lora_model):
        raise FileNotFoundError(f"LoRA model directory {args.lora_model} not found.")
    os.makedirs(args.output, exist_ok=True)

    # Load the base model and tokenizer; fp16 halves memory versus fp32.
    model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    # Adjust tokenizer and model for any additional tokens.
    smart_tokenizer_and_embedding_resize(tokenizer, model, args.custom_tokens)

    # Load the LoRA adapter, then fold its weights into the base model.
    logger.info("Loading and merging the LoRA model...")
    lora_model = PeftModel.from_pretrained(model, args.lora_model)
    merged_model = lora_model.merge_and_unload()

    # Save the merged model and tokenizer.
    merged_model.save_pretrained(args.output)
    tokenizer.save_pretrained(args.output)
    logger.info(f"Merged model saved to {args.output}")


if __name__ == "__main__":
    main()
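
# Example invocation (a sketch; the script name and the adapter/output paths
# below are hypothetical and should be replaced with your own):
#
#   python merge.py \
#       --base_model meta-llama/Llama-2-7b-chat-hf \
#       --lora_model ./checkpoints/lora-adapter \
#       --output ./merged-llama2-7b-chat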