Update pipeline.py
pipeline.py CHANGED (+45 -13)
@@ -10,6 +10,7 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import subprocess  # Import subprocess to run shell commands
 from langchain.llms.base import LLM  # Import LLM
+from mistralai import Mistral  # Import Mistral for moderation
 
 # Import the chain builders from our separate files
 from classification_chain import get_classification_chain
@@ -22,6 +23,11 @@ if not os.environ.get("GEMINI_API_KEY"):
     os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
 if not os.environ.get("GROQ_API_KEY"):
     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
+if not os.environ.get("MISTRAL_API_KEY"):
+    os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter your Mistral API Key: ")
+
+# Initialize Mistral client
+mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
 
 # 2) Load spaCy model for NER and download the spaCy model if not already installed
 def install_spacy_model():
@@ -67,7 +73,28 @@ def extract_main_topic(query: str) -> str:
     # Return the extracted topic or a fallback value if no topic is found
     return main_topic if main_topic else "this topic"
 
-# 3)
+# 3) Function to moderate text using Mistral moderation API
+def moderate_text(query: str) -> str:
+    """
+    Classifies the query as harmful or not using Mistral Moderation API.
+    Returns "OutOfScope" if harmful, otherwise returns the original query.
+    """
+    response = mistral_client.classifiers.moderate(
+        model="mistral-moderation-latest",
+        inputs=[query]
+    )
+
+    categories = response.results[0].categories
+
+    # Check if any harmful category is flagged
+    if categories.get("violence_and_threats", False) or \
+       categories.get("hate_and_discrimination", False) or \
+       categories.get("dangerous_and_criminal_content", False) or \
+       categories.get("selfharm", False):
+        return "OutOfScope"
+    return query
+
+# 4) build_or_load_vectorstore (no changes)
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
         print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
@@ -96,7 +123,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     vectorstore.save_local(store_dir)
     return vectorstore
 
-#
+# 5) Build RAG chain for Gemini (no changes)
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
@@ -117,13 +144,13 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     )
     return rag_chain
 
-#
+# 6) Initialize all the separate chains
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()  # Refusal chain will now use dynamic topic
 tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
-#
+# 7) Build our vectorstores + RAG chains
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
@@ -136,7 +163,7 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
-#
+# 8) Tools / Agents for web search (no changes)
 search_tool = DuckDuckGoSearchTool()
 web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
 managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
@@ -148,19 +175,24 @@ def do_web_search(query: str) -> str:
     response = manager_agent.run(search_query)
     return response
 
-#
+# 9) Orchestrator: run_with_chain
 def run_with_chain(query: str) -> str:
     print("DEBUG: Starting run_with_chain...")
 
-    # 1)
-    class_result = classification_chain.invoke({"query": query})
+    # 1) Moderate the query for harmful content
+    moderated_query = moderate_text(query)
+    if moderated_query == "OutOfScope":
+        return "Sorry, this query contains harmful or inappropriate content."
+
+    # 2) Classify the query
+    class_result = classification_chain.invoke({"query": moderated_query})
     classification = class_result.get("text", "").strip()
     print("DEBUG: Classification =>", classification)
 
     # If OutOfScope => refusal => tailor => return
     if classification == "OutOfScope":
         # Extract the main topic for the refusal message
-        topic = extract_main_topic(query)
+        topic = extract_main_topic(moderated_query)
         print("DEBUG: Extracted Topic =>", topic)
 
         # Pass the extracted topic to the refusal chain
@@ -170,14 +202,14 @@ def run_with_chain(query: str) -> str:
 
     # If Wellness => wellness RAG => if insufficient => web => unify => tailor
     if classification == "Wellness":
-        rag_result = wellness_rag_chain({"query": query})
+        rag_result = wellness_rag_chain({"query": moderated_query})
         csv_answer = rag_result["result"].strip()
         if not csv_answer:
-            web_answer = do_web_search(query)
+            web_answer = do_web_search(moderated_query)
         else:
             lower_ans = csv_answer.lower()
             if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
-                web_answer = do_web_search(query)
+                web_answer = do_web_search(moderated_query)
             else:
                 web_answer = ""
         final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
@@ -186,7 +218,7 @@ def run_with_chain(query: str) -> str:
 
     # If Brand => brand RAG => tailor => return
     if classification == "Brand":
-        rag_result = brand_rag_chain({"query": query})
+        rag_result = brand_rag_chain({"query": moderated_query})
         csv_answer = rag_result["result"].strip()
         final_merged = cleaner_chain.merge(kb=csv_answer, web="")
         final_answer = tailor_chain.run({"response": final_merged})