yunzi7's picture
pr #8 (#8)
1c1b26d verified
raw
history blame contribute delete
765 Bytes
import keras_nlp
# keras_nlp preset name of the base Gemma 2 instruction-tuned 2B model.
MODEL_NAME = "gemma2_instruct_2b_en"

# LoRA adapter file applied on top of MODEL_NAME. An earlier "..._v2_3"
# adapter file in the same directory was referenced here previously.
# NOTE(review): this is a relative path, so it is presumably resolved against
# the process's current working directory — confirm the app's launch dir.
LORA_WEIGHT_PATH = (
    "ice_breaking_challenge/models/"
    "gemma2_it_2b_icebreaking_quiz_v2_5.lora.h5"
)
def load_model_with_lora(
    model_name: str = MODEL_NAME,
    lora_weight_path: str = LORA_WEIGHT_PATH,
):
    """Load a Keras-based Gemma causal LM and apply LoRA weights to it.

    The preset named by ``model_name`` is instantiated with
    ``keras_nlp.models.GemmaCausalLM.from_preset``. The LoRA weights stored
    at ``lora_weight_path`` are then loaded into that model's backbone.

    Args:
        model_name (str): Name of the model preset to load.
            Defaults to ``MODEL_NAME``.
        lora_weight_path (str): Path of the LoRA weight file to apply.
            Defaults to ``LORA_WEIGHT_PATH``.

    Returns:
        keras_nlp.models.GemmaCausalLM: The loaded model, with the LoRA
        weights applied to its backbone.
    """
    causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_name)

    # LoRA weights are loaded through the backbone, not the task-level model.
    backbone = causal_lm.backbone
    backbone.load_lora_weights(lora_weight_path)

    return causal_lm