CineAI's picture
Update llm/huggingfacehub/hf_model.py
e805397 verified
raw
history blame
4.77 kB
import logging
import os
import yaml
from abc import ABC
from llm.llm_interface import LLMInterface
from llm.config import config
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceEndpoint
# Module-level logger; its records are written to a dedicated file for this module.
logger = logging.getLogger(__name__)

# For all modules the log file name follows the template
# "logs/chelsea_{module_name}_{dir_name}.log".
_log_file = "logs/chelsea_llm_huggingfacehub.log"
# FileHandler opens the file immediately and raises if the directory is
# missing, so make sure the directory exists before creating the handler.
os.makedirs(os.path.dirname(_log_file), exist_ok=True)
file_handler = logging.FileHandler(_log_file)

# The level used to be set to CRITICAL and then immediately overridden with
# INFO; only the final INFO level ever took effect, so it is set once here.
logger.setLevel(logging.INFO)  # informed
formatted = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatted)
logger.addHandler(file_handler)
logger.info("Getting information from hf_model module")

# Directory that holds prompts.yaml. The joined component must be relative:
# os.path.join() throws away everything before an absolute component, so the
# previous "/llm/" always produced the filesystem-root path "/llm/" regardless
# of the working directory. The trailing "" keeps the trailing separator the
# old value had.
path_to_prompts = os.path.join(os.getcwd(), "llm", "")
print(path_to_prompts)
class HF_Mistaril(LLMInterface, ABC):
    """Run a prompt from ``prompts.yaml`` against a Hugging Face endpoint.

    The endpoint settings (model repo id, temperature, max_new_tokens, top_k,
    load_in_8bit) are read from ``config["HF_Mistrail"]``.
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        """Store the prompt inputs and build the Hugging Face endpoint client.

        Args:
            prompt_entity: Value passed to the chain when it is invoked.
            prompt_id: Index into the ``prompts`` list of ``prompts.yaml``
                (0 selects the first prompt).
        """
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id
        self.model_config = config["HF_Mistrail"]
        self.llm = HuggingFaceEndpoint(
            repo_id=self.model_config["model"],
            model_kwargs={
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
                "load_in_8bit": self.model_config["load_in_8bit"],
            },
        )

    @staticmethod
    def __read_yaml():
        """Load ``prompts.yaml`` from ``path_to_prompts``.

        Returns:
            The parsed YAML document, or ``None`` when the file cannot be
            opened or parsed (the failure is printed and logged).
        """
        try:
            yaml_file = os.path.join(path_to_prompts, 'prompts.yaml')
            # Read as UTF-8 explicitly so decoding does not depend on the
            # platform's default locale encoding.
            with open(yaml_file, 'r', encoding="utf-8") as file:
                return yaml.safe_load(file)
        except Exception as e:
            print(f"Execution filed : {e}")
            logger.error(msg="Execution filed", exc_info=e)
            return None

    def execution(self):
        """Build the prompt selected by ``prompt_id`` and run it through the LLM.

        Returns:
            The ``"text"`` entry of the chain output, or ``None`` when any
            step fails (the failure is printed and logged as critical).
        """
        try:
            data = self.__read_yaml()
            if data is None:
                # Covers both a failed read (already logged by __read_yaml)
                # and an empty YAML file; report the real cause instead of a
                # "'NoneType' object is not subscriptable" error.
                raise ValueError("prompts.yaml could not be loaded or is empty")
            # prompt_id is the index of the prompt inside the "prompts" list;
            # pass a different id to the constructor to select another prompt.
            prompt_entry = data["prompts"][self.prompt_id]
            template = prompt_entry["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            # Best-effort contract: callers receive None instead of an exception.
            print(f"Execution filed : {e}")
            logger.critical(msg="Execution filed", exc_info=e)
            return None

    def __str__(self):
        """Return a short description of the prompt inputs."""
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        """Return the class name with the type and value of each prompt input."""
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"
class HF_TinyLlama(LLMInterface, ABC):
    """Run a prompt from ``prompts.yaml`` against a Hugging Face endpoint.

    The endpoint settings (model repo id, temperature, max_new_tokens, top_k,
    load_in_8bit) are read from ``config["HF_TinyLlama"]``.
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        """Store the prompt inputs and build the Hugging Face endpoint client.

        Args:
            prompt_entity: Value passed to the chain when it is invoked.
            prompt_id: Index into the ``prompts`` list of ``prompts.yaml``
                (0 selects the first prompt).
        """
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id
        self.model_config = config["HF_TinyLlama"]
        self.llm = HuggingFaceEndpoint(
            repo_id=self.model_config["model"],
            model_kwargs={
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
                "load_in_8bit": self.model_config["load_in_8bit"],
            },
        )

    @staticmethod
    def __read_yaml():
        """Load ``prompts.yaml`` from ``path_to_prompts``.

        Returns:
            The parsed YAML document, or ``None`` when the file cannot be
            opened or parsed (the failure is printed and logged).
        """
        try:
            yaml_file = os.path.join(path_to_prompts, 'prompts.yaml')
            # Read as UTF-8 explicitly so decoding does not depend on the
            # platform's default locale encoding.
            with open(yaml_file, 'r', encoding="utf-8") as file:
                return yaml.safe_load(file)
        except Exception as e:
            print(f"Execution filed : {e}")
            logger.error(msg="Execution filed", exc_info=e)
            return None

    def execution(self):
        """Build the prompt selected by ``prompt_id`` and run it through the LLM.

        Returns:
            The ``"text"`` entry of the chain output, or ``None`` when any
            step fails (the failure is printed and logged as critical).
        """
        try:
            data = self.__read_yaml()
            if data is None:
                # Covers both a failed read (already logged by __read_yaml)
                # and an empty YAML file; report the real cause instead of a
                # "'NoneType' object is not subscriptable" error.
                raise ValueError("prompts.yaml could not be loaded or is empty")
            # prompt_id is the index of the prompt inside the "prompts" list;
            # pass a different id to the constructor to select another prompt.
            prompt_entry = data["prompts"][self.prompt_id]
            template = prompt_entry["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            # Best-effort contract: callers receive None instead of an exception.
            print(f"Execution filed : {e}")
            logger.critical(msg="Execution filed", exc_info=e)
            return None

    def __str__(self):
        """Return a short description of the prompt inputs."""
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        """Return the class name with the type and value of each prompt input."""
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"