Update llm/llm.py
llm/llm.py (+21 -6)
CHANGED
@@ -1,8 +1,10 @@
+import yaml
+import logging
+
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
 from langchain.llms import HuggingFaceHub
 from .config import config
-from .prompts import prompts
 
 class LLM_chain:
     def __init__(self):
@@ -10,9 +12,22 @@ class LLM_chain:
             repo_id=config["model"],
             model_kwargs={"temperature": config["temperature"], "max_new_tokens": config["max_new_tokens"], "top_k": config["top_k"], "load_in_8bit": config["load_in_8bit"]})
 
+    def __read_yaml(self):
+        try:
+            with open("./prompts.yaml", "r") as file:
+                data = yaml.safe_load(file)
+                return data
+        except Exception as e:
+            logging.error(e)
+
     def __call__(self, entity: str, id: int = 0):
-
-
-
-
-
+        try:
+            data = self.__read_yaml()
+            prompts = data["prompts"]
+            template = prompts["prompt_template"][1]
+            prompt = PromptTemplate(template=template, input_variables=["entity"])
+            llm_chain = LLMChain(prompt=prompt, llm=self.llm, verbose=True)
+            output = llm_chain.invoke(entity)
+            return output["text"]
+        except Exception as e:
+            logging.error(e)
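The new __read_yaml helper loads ./prompts.yaml from the working directory, and __call__ then pulls prompts["prompt_template"][1] out of it. Below is a minimal sketch of the YAML layout that implies; only the "prompts" and "prompt_template" keys and the list indexing come from the diff, while the template texts themselves are invented placeholders:

# Hedged sketch of the prompts.yaml layout implied by __read_yaml/__call__.
# Key names "prompts" and "prompt_template" are from the diff; the template
# wording here is an assumption for illustration only.
import yaml

example_yaml = """
prompts:
  prompt_template:
    - "Template 0: say something brief about {entity}."
    - "Template 1: give a detailed description of {entity}."
"""

data = yaml.safe_load(example_yaml)
template = data["prompts"]["prompt_template"][1]  # index 1, as in __call__
print(template)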
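For context, a usage sketch under stated assumptions: the package is importable as llm.llm, config names a valid Hugging Face Hub repo id, and the HUGGINGFACEHUB_API_TOKEN environment variable is set (HuggingFaceHub reads it when the class is constructed). Note that on any failure both methods only log the exception and implicitly return None, so a caller may want to guard against a None result:

# Hypothetical usage sketch; the "llm.llm" import path is an assumption.
from llm.llm import LLM_chain

chain = LLM_chain()
result = chain("the Eiffel Tower")  # fills the template's {entity} variable
if result is not None:
    print(result)  # LLMChain.invoke returns {"text": ...}; __call__ unwraps it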