Spaces:
Runtime error
Runtime error
import keras_nlp | |
# Keras NLP preset name of the base model; used as the default value of
# ``load_model_with_lora(model_name=...)``.
MODEL_NAME = "gemma2_instruct_2b_en"

# LoRA adapter weight file applied on top of the base model; used as the
# default value of ``load_model_with_lora(lora_weight_path=...)``.
# This is a relative path, so it resolves against the working directory.
# Earlier adapter version, kept for reference:
#   ice_breaking_challenge/models/gemma2_it_2b_icebreaking_quiz_v2_3.lora.h5
LORA_WEIGHT_PATH = (
    "ice_breaking_challenge/models/"
    "gemma2_it_2b_icebreaking_quiz_v2_5.lora.h5"
)
def load_model_with_lora(
    model_name: str = MODEL_NAME,
    lora_weight_path: str = LORA_WEIGHT_PATH,
) -> "keras_nlp.models.GemmaCausalLM":
    """Load a Keras NLP Gemma causal LM and apply LoRA weights to it.

    Args:
        model_name (str): Name of the preset to load, passed to
            ``keras_nlp.models.GemmaCausalLM.from_preset``.
        lora_weight_path (str): Path of the LoRA weight file to apply to the
            model's backbone. Any ``os.PathLike`` is also accepted and is
            converted to ``str`` first. Relative paths are resolved against
            the current working directory.

    Returns:
        keras_nlp.models.GemmaCausalLM: The loaded model, with the LoRA
        weights loaded into ``model.backbone``.

    Raises:
        FileNotFoundError: If ``lora_weight_path`` is not an existing file.
            This is checked before the base model is loaded.
    """
    import os

    # Normalise PathLike inputs; the backbone's LoRA loader receives a str.
    lora_weight_path = os.fspath(lora_weight_path)

    # Fail fast: loading the preset is the expensive step (presumably a
    # multi-GB download/initialisation), so verify the adapter file exists
    # before doing it rather than discovering a bad path afterwards.
    # The cwd is included because the default path is relative.
    if not os.path.isfile(lora_weight_path):
        raise FileNotFoundError(
            f"LoRA weight file not found: {lora_weight_path!r} "
            f"(current working directory: {os.getcwd()!r})"
        )

    model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
    model.backbone.load_lora_weights(lora_weight_path)
    return model