|
import logging |
|
import os |
|
import yaml |
|
|
|
from abc import ABC |
|
|
|
from llm.hf_interface import HFInterface |
|
from llm.config import config |
|
|
|
from langchain.prompts import PromptTemplate |
|
from langchain.chains import LLMChain |
|
from langchain.llms import HuggingFaceEndpoint |
|
|
|
# Module logger: records at INFO and above are written to a dedicated log file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# logging.FileHandler does not create missing parent directories, so make sure
# the (cwd-relative) "logs" directory exists before opening the log file;
# otherwise importing this module raises FileNotFoundError.
os.makedirs("logs", exist_ok=True)
file_handler = logging.FileHandler("logs/chelsea_llm_huggingfacehub.log")

formatted = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatted)
logger.addHandler(file_handler)

logger.info("Getting information from hf_model module")

# Directory that holds prompts.yaml: "<cwd>/llm/" (trailing separator kept).
# The "llm" component must be relative: os.path.join discards every earlier
# component when it meets an absolute one, so "/llm/" would point at the
# filesystem root instead of the current working directory.
path_to_prompts = os.path.join(os.getcwd(), "llm", "")
logger.info("Prompts directory: %s", path_to_prompts)
|
|
|
|
|
class _HFPromptModel(HFInterface, ABC):
    """Shared implementation for prompt-driven Hugging Face Hub models.

    Concrete subclasses only choose which section of ``config`` describes the
    model by setting ``_CONFIG_KEY``. That section must provide the keys
    ``model``, ``temperature``, ``max_new_tokens``, ``top_k`` and
    ``load_in_8bit``.
    """

    # Key into ``config`` for this model's settings; set by each subclass.
    _CONFIG_KEY = ""

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        """Store the prompt inputs and build the Hugging Face endpoint client.

        Args:
            prompt_entity: Input passed to the chain as the ``entity`` prompt
                variable.
            prompt_id: Index into the ``prompts`` list of ``prompts.yaml``.

        Raises:
            KeyError: If the config section or one of its keys is missing.
        """
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id

        self.model_config = config[self._CONFIG_KEY]

        # temperature / max_new_tokens / top_k are declared fields of
        # HuggingFaceEndpoint and must be passed explicitly: recent LangChain
        # versions reject declared fields supplied inside ``model_kwargs``.
        self.llm = HuggingFaceEndpoint(
            repo_id=self.model_config["model"],
            temperature=self.model_config["temperature"],
            max_new_tokens=self.model_config["max_new_tokens"],
            top_k=self.model_config["top_k"],
            # NOTE(review): load_in_8bit is a model-loading option; it is
            # forwarded unchanged and may have no effect on a hosted endpoint.
            model_kwargs={"load_in_8bit": self.model_config["load_in_8bit"]},
        )

    @staticmethod
    def _read_prompts():
        """Return the parsed contents of ``prompts.yaml`` in ``path_to_prompts``.

        Errors (missing file, malformed YAML) propagate to the caller, so they
        are reported once, with their real cause, by ``execution``.
        """
        yaml_file = os.path.join(path_to_prompts, "prompts.yaml")
        with open(yaml_file, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)

    def execution(self):
        """Render the selected prompt template and run it through the model.

        Returns:
            The generated text (the ``"text"`` entry of the chain output), or
            ``None`` if any step fails. Failures are printed and logged at
            CRITICAL level instead of being raised.
        """
        try:
            data = self._read_prompts()
            template = data["prompts"][self.prompt_id]["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            # Deliberate catch-all: callers receive None rather than an exception.
            print(f"Execution failed : {e}")
            logger.critical(msg="Execution failed", exc_info=e)
            return None

    def __str__(self):
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"


class HF_Mistaril(_HFPromptModel):
    """Prompt runner for the model configured under ``config["HF_Mistrail"]``."""

    # Spelling presumably has to match the key in the config source; verify
    # there before renaming.
    _CONFIG_KEY = "HF_Mistrail"


class HF_TinyLlama(_HFPromptModel):
    """Prompt runner for the model configured under ``config["HF_TinyLlama"]``."""

    _CONFIG_KEY = "HF_TinyLlama"
|
|