import logging
import os
import yaml

from abc import ABC

from llm.llm_interface import LLMInterface
from llm.config import config

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceEndpoint

logger = logging.getLogger(__name__)

logger.setLevel(logging.INFO)

# Every module writes its logs to a file following the template
# "logs/chelsea_{module_name}_{dir_name}.log"; execution failures are logged as
# CRITICAL because the application cannot continue after them.
file_handler = logging.FileHandler("logs/chelsea_llm_huggingfacehub.log")

formatted = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatted)

logger.addHandler(file_handler)

logger.info("Getting information from hf_model module")
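
# Expected shape of prompts.yaml, inferred from how execution() indexes it
# (the real file may contain additional fields):
#
#   prompts:
#     - prompt_template: "... {entity} ..."
#     - prompt_template: "... {entity} ..."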


class HF_Mistaril(LLMInterface, ABC):
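    """Wrapper around a Hugging Face Inference endpoint configured under
    config["HF_Mistrail"]. Loads a prompt template from prompts.yaml, fills it
    with `prompt_entity` and returns the generated text."""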
    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id

        self.model_config = config["HF_Mistrail"]

        self.llm = HuggingFaceEndpoint(
            repo_id=self.model_config["model"],
            model_kwargs={"temperature": self.model_config["temperature"],
                          "max_new_tokens": self.model_config["max_new_tokens"],
                          "top_k": self.model_config["top_k"], "load_in_8bit": self.model_config["load_in_8bit"]})

    @staticmethod
    def __read_yaml():
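        """Load prompt definitions from prompts.yaml.

        Note that "../prompts.yaml" is resolved relative to the current working
        directory, not relative to this file."""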
        try:
            yaml_file = os.path.join("../", 'prompts.yaml')
            with open(yaml_file, 'r') as file:
                data = yaml.safe_load(file)
            return data
        except Exception as e:
            logger.error(msg="Reading prompts.yaml failed", exc_info=e)

    def execution(self):
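        """Fill the prompt template selected by `prompt_id` with `prompt_entity`,
        run it through the endpoint and return the generated text."""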
        try:
            data = self.__read_yaml()
            # Select the prompt at index prompt_id from prompts.yaml; pass a different
            # prompt_id to use another prompt.
            prompts = data["prompts"][self.prompt_id]
            template = prompts["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            logger.critical(msg="Execution failed", exc_info=e)

    def __str__(self):
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"


class HF_TinyLlama(LLMInterface, ABC):
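    """Wrapper around a Hugging Face Inference endpoint configured under
    config["HF_TinyLlama"]. Behaves like HF_Mistaril but uses the TinyLlama
    endpoint settings."""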
    def __init__(self, prompt_entity: str, prompt_id: int = 0):
        self.prompt_entity = prompt_entity
        self.prompt_id = prompt_id

        self.model_config = config["HF_TinyLlama"]

        self.llm = HuggingFaceEndpoint(
            repo_id=self.model_config["model"],
            model_kwargs={"temperature": self.model_config["temperature"],
                          "max_new_tokens": self.model_config["max_new_tokens"],
                          "top_k": self.model_config["top_k"], "load_in_8bit": self.model_config["load_in_8bit"]})

    @staticmethod
    def __read_yaml():
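        """Load prompt definitions from prompts.yaml (resolved relative to the
        current working directory, not relative to this file)."""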
        try:
            yaml_file = os.path.join("../", 'prompts.yaml')
            with open(yaml_file, 'r') as file:
                data = yaml.safe_load(file)
            return data
        except Exception as e:
            logger.error(msg="Reading prompts.yaml failed", exc_info=e)

    def execution(self):
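        """Fill the prompt template selected by `prompt_id` with `prompt_entity`,
        run it through the endpoint and return the generated text."""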
        try:
            data = self.__read_yaml()
            # Select the prompt at index prompt_id from prompts.yaml; pass a different
            # prompt_id to use another prompt.
            prompts = data["prompts"][self.prompt_id]
            template = prompts["prompt_template"]
            prompt = PromptTemplate(template=template, input_variables=["entity"])
            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
            output = llm_chain.invoke(self.prompt_entity)
            return output["text"]
        except Exception as e:
            logger.critical(msg="Execution failed", exc_info=e)

    def __str__(self):
        return f"prompt_entity={self.prompt_entity}, prompt_id={self.prompt_id}"

    def __repr__(self):
        return f"{self.__class__.__name__}(prompt_entity: {type(self.prompt_entity)} = {self.prompt_entity}, prompt_id: {type(self.prompt_id)} = {self.prompt_id})"