from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
import os

from dotenv import load_dotenv

# Load environment variables (e.g. HF_TOKEN) from a local .env file.
load_dotenv()

# Module-level cache for the most recently initialized endpoint and chat
# model; None until init_llama_chatmodel() has been called.
llm = None
chat_model = None

# Hugging Face API token from the environment; None if unset.
HF_TOKEN = os.getenv("HF_TOKEN")
def init_llama_chatmodel(repo_id, task="text-generation"):
    """Initialize a ChatHuggingFace model backed by a HuggingFaceEndpoint.

    The endpoint and chat model are also stored in the module-level
    ``llm`` and ``chat_model`` globals so other code can reuse them.

    Args:
        repo_id: Hugging Face model repository id (e.g. a Llama repo).
        task: Endpoint task; defaults to "text-generation".

    Returns:
        The initialized ChatHuggingFace chat model.
    """
    global llm, chat_model
    # NOTE(review): authentication presumably comes from the HF_TOKEN
    # environment variable loaded by load_dotenv() above — confirm that
    # HuggingFaceEndpoint picks it up from the environment.
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        task=task,
    )
    chat_model = ChatHuggingFace(llm=llm)
    return chat_model
# NOTE: dead code retained for reference — an earlier split of endpoint
# construction and chat-model wrapping, superseded by init_llama_chatmodel.
# def get_llama_endpoint(repo_id):
#     llm = HuggingFaceEndpoint(
#         repo_id=repo_id,
#         task="text-generation",
#         max_new_tokens=512,
#         do_sample=False,
#         repetition_penalty=1.03,
#     )
#     return llm
# def get_llama_chatmodel(llm):
#     chat_model = ChatHuggingFace(llm=llm)
#     return chat_model