# Source: CongMa/models/shared.py (uploaded by XuBailing in "Upload 243 files",
# commit 107f987; 1.73 kB).
import sys
from typing import Any, Optional

from models.loader.args import parser
from models.loader import LoaderCheckPoint
from configs.model_config import (llm_model_dict, LLM_MODEL)
from models.base import BaseAnswer
# Shared checkpoint that loaderLLM() reads and mutates (model_name, model_path,
# load/unload). It starts as None, so it must be assigned before loaderLLM() is
# called (presumably by the application entry point — confirm against callers).
loaderCheckPoint: Optional[LoaderCheckPoint] = None
def loaderLLM(llm_model: Optional[str] = None, no_remote_model: bool = False, use_ptuning_v2: bool = False) -> Any:
    """
    Point the shared ``loaderCheckPoint`` at an LLM and return a model instance for it.

    :param llm_model: key of ``llm_model_dict`` naming the model to load. When falsy,
        the entry keyed by the checkpoint's current ``model_name`` is used.
    :param no_remote_model: when True, set ``loaderCheckPoint.no_remote_model`` (the
        ``--no-remote-model`` option for loading a local model). This makes the
        checkpoint use the entry's ``name`` instead of its ``pretrained_model_name``.
        False leaves the checkpoint's current setting unchanged.
    :param use_ptuning_v2: when True, set ``loaderCheckPoint.use_ptuning_v2`` to use the
        p-tuning-v2 PrefixEncoder. False leaves the current setting unchanged.
    :return: an instance of the class named by the entry's ``provides`` field. The class
        is taken from the ``models`` package and constructed with
        ``checkPoint=loaderCheckPoint``.
    :raises KeyError: if the selected model is not a key of ``llm_model_dict``, or the
        entry lacks one of the fields read below.
    """
    # Resolve exactly one config entry, before any shared state is mutated.
    # Looking up the current model_name unconditionally would raise KeyError even
    # when an explicit llm_model is given, whenever model_name is not itself a dict
    # key (this function overwrites it with 'name'/'pretrained_model_name' below).
    model_key = llm_model if llm_model else loaderCheckPoint.model_name
    llm_model_info = llm_model_dict[model_key]

    # These options can only be switched on here; False keeps the prior value.
    if no_remote_model:
        loaderCheckPoint.no_remote_model = no_remote_model
    if use_ptuning_v2:
        loaderCheckPoint.use_ptuning_v2 = use_ptuning_v2

    if loaderCheckPoint.no_remote_model:
        loaderCheckPoint.model_name = llm_model_info['name']
    else:
        loaderCheckPoint.model_name = llm_model_info['pretrained_model_name']
    loaderCheckPoint.model_path = llm_model_info["local_model_path"]

    provides = llm_model_info["provides"]
    # Substring match, so any class name containing 'FastChatOpenAILLM' is
    # treated as an API-backed model.
    is_fastchat_api = 'FastChatOpenAILLM' in provides

    # API-backed models are given an api_base_url below, so the checkpoint's
    # locally loaded model is released instead of being (re)loaded.
    if is_fastchat_api:
        loaderCheckPoint.unload_model()
    else:
        loaderCheckPoint.reload_model()

    # ``provides`` names a class exposed by the already-imported ``models`` package.
    provides_class = getattr(sys.modules['models'], provides)
    modelInsLLM = provides_class(checkPoint=loaderCheckPoint)
    if is_fastchat_api:
        modelInsLLM.set_api_base_url(llm_model_info['api_base_url'])
        modelInsLLM.call_model_name(llm_model_info['name'])
    return modelInsLLM