File size: 4,769 Bytes
d5436e0 e805397 d5436e0 e805397 d5436e0 5cc25cc d5436e0 5cc25cc d5436e0 e805397 d5436e0 5cc25cc d5436e0 5cc25cc d5436e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import logging
import os
import yaml
from abc import ABC
from llm.llm_interface import LLMInterface
from llm.config import config
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceEndpoint
logger = logging.getLogger(__name__)
# Records of INFO and above from this module are written to a dedicated file.
logger.setLevel(logging.INFO)
# Every module logs to "logs/chelsea_{module_name}_{dir_name}.log", relative to
# the current working directory. FileHandler does not create missing
# directories, so make sure "logs/" exists before opening the file.
os.makedirs("logs", exist_ok=True)
file_handler = logging.FileHandler("logs/chelsea_llm_huggingfacehub.log")
formatted = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatted)
logger.addHandler(file_handler)
logger.info("Getting information from hf_model module")

# Directory that holds prompts.yaml: "<cwd>/llm/" (with a trailing separator).
# Every component after the first must be relative: os.path.join drops all
# preceding components when it meets an absolute one, so "/llm/" would point
# at the filesystem root instead of the project's llm/ package.
path_to_prompts = os.path.join(os.getcwd(), "llm", "")
logger.info("Prompts directory: %s", path_to_prompts)
class HF_Mistaril(LLMInterface, ABC):
    """Run a prompt from ``prompts.yaml`` through a Hugging Face endpoint.

    The endpoint and its generation settings are read from
    ``config["HF_Mistrail"]`` (keys ``model``, ``temperature``,
    ``max_new_tokens``, ``top_k`` and ``load_in_8bit``).

    Args:
        prompt_entity: Value passed to the chain as its single input; the
            prompt template is declared with the input variable ``entity``.
        prompt_id: Index into the ``prompts`` list of ``prompts.yaml``
            (0 selects the first prompt).
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id
        # The config key is spelled "HF_Mistrail" (unlike the class name); it
        # must match the key defined in llm.config, so it is kept verbatim.
        self.model_config = config["HF_Mistrail"]
        # NOTE(review): some langchain releases of HuggingFaceEndpoint reject
        # temperature/max_new_tokens/top_k inside model_kwargs and expect them
        # as explicit keyword arguments; load_in_8bit also looks like a
        # model-loading option rather than a generation parameter. Confirm
        # against the pinned langchain version.
        self.llm = HuggingFaceEndpoint(
            repo_id=self.model_config["model"],
            model_kwargs={
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
                "load_in_8bit": self.model_config["load_in_8bit"],
            },
        )

    @staticmethod
    def __read_yaml():
        """Load and return the parsed ``prompts.yaml`` from ``path_to_prompts``.

        Returns:
            The YAML document as a mapping containing a ``prompts`` key.

        Raises:
            OSError: If the file cannot be opened or read.
            yaml.YAMLError: If the file is not valid YAML.
            ValueError: If the document is not a mapping with a ``prompts`` key.
        """
        yaml_file = os.path.join(path_to_prompts, 'prompts.yaml')
        try:
            with open(yaml_file, 'r', encoding='utf-8') as file:
                data = yaml.safe_load(file)
        except (OSError, yaml.YAMLError) as e:
            # Log with the path, then propagate: returning None here would only
            # surface later as an opaque "'NoneType' object is not subscriptable".
            logger.error("Failed to read prompts from %s", yaml_file, exc_info=e)
            raise
        # safe_load returns None for an empty file, so validate the shape here.
        if not isinstance(data, dict) or "prompts" not in data:
            raise ValueError(f"{yaml_file} has no top-level 'prompts' key")
        return data

    def execution(self):
        """Render prompt ``prompt_id`` with ``prompt_entity`` and query the model.

        Returns:
            The generated text (the chain output's ``"text"`` entry), or
            ``None`` if any step fails. Failures are printed and logged at
            CRITICAL level instead of being raised.
        """
        try:
            data = self.__read_yaml()
            # prompt_id selects which prompt to use (0 = first in the list).
            prompts = data["prompts"][self.prompt_id]
            template = prompts["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            # Top-level boundary for callers: report and return None.
            print(f"Execution failed : {e}")
            logger.critical(msg="Execution failed", exc_info=e)
            return None

    def __str__(self):
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"
class HF_TinyLlama(LLMInterface, ABC):
    """Run a prompt from ``prompts.yaml`` through a Hugging Face endpoint.

    The endpoint and its generation settings are read from
    ``config["HF_TinyLlama"]`` (keys ``model``, ``temperature``,
    ``max_new_tokens``, ``top_k`` and ``load_in_8bit``).

    Args:
        prompt_entity: Value passed to the chain as its single input; the
            prompt template is declared with the input variable ``entity``.
        prompt_id: Index into the ``prompts`` list of ``prompts.yaml``
            (0 selects the first prompt).
    """

    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id
        self.model_config = config["HF_TinyLlama"]
        # NOTE(review): some langchain releases of HuggingFaceEndpoint reject
        # temperature/max_new_tokens/top_k inside model_kwargs and expect them
        # as explicit keyword arguments; load_in_8bit also looks like a
        # model-loading option rather than a generation parameter. Confirm
        # against the pinned langchain version.
        self.llm = HuggingFaceEndpoint(
            repo_id=self.model_config["model"],
            model_kwargs={
                "temperature": self.model_config["temperature"],
                "max_new_tokens": self.model_config["max_new_tokens"],
                "top_k": self.model_config["top_k"],
                "load_in_8bit": self.model_config["load_in_8bit"],
            },
        )

    @staticmethod
    def __read_yaml():
        """Load and return the parsed ``prompts.yaml`` from ``path_to_prompts``.

        Returns:
            The YAML document as a mapping containing a ``prompts`` key.

        Raises:
            OSError: If the file cannot be opened or read.
            yaml.YAMLError: If the file is not valid YAML.
            ValueError: If the document is not a mapping with a ``prompts`` key.
        """
        yaml_file = os.path.join(path_to_prompts, 'prompts.yaml')
        try:
            with open(yaml_file, 'r', encoding='utf-8') as file:
                data = yaml.safe_load(file)
        except (OSError, yaml.YAMLError) as e:
            # Log with the path, then propagate: returning None here would only
            # surface later as an opaque "'NoneType' object is not subscriptable".
            logger.error("Failed to read prompts from %s", yaml_file, exc_info=e)
            raise
        # safe_load returns None for an empty file, so validate the shape here.
        if not isinstance(data, dict) or "prompts" not in data:
            raise ValueError(f"{yaml_file} has no top-level 'prompts' key")
        return data

    def execution(self):
        """Render prompt ``prompt_id`` with ``prompt_entity`` and query the model.

        Returns:
            The generated text (the chain output's ``"text"`` entry), or
            ``None`` if any step fails. Failures are printed and logged at
            CRITICAL level instead of being raised.
        """
        try:
            data = self.__read_yaml()
            # prompt_id selects which prompt to use (0 = first in the list).
            prompts = data["prompts"][self.prompt_id]
            template = prompts["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            # Top-level boundary for callers: report and return None.
            print(f"Execution failed : {e}")
            logger.critical(msg="Execution failed", exc_info=e)
            return None

    def __str__(self):
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"
|