Phoenix21 committed
Commit e27c8c7 · verified · 1 Parent(s): 5144d12

Update pipeline.py

Files changed (1): pipeline.py (+45 -13)
pipeline.py CHANGED
@@ -10,6 +10,7 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import subprocess  # Import subprocess to run shell commands
 from langchain.llms.base import LLM  # Import LLM
+from mistralai import Mistral  # Import Mistral for moderation
 
 # Import the chain builders from our separate files
 from classification_chain import get_classification_chain
@@ -22,6 +23,11 @@ if not os.environ.get("GEMINI_API_KEY"):
     os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
 if not os.environ.get("GROQ_API_KEY"):
     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
+if not os.environ.get("MISTRAL_API_KEY"):
+    os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter your Mistral API Key: ")
+
+# Initialize Mistral client
+mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
 
 # 2) Load spaCy model for NER and download the spaCy model if not already installed
 def install_spacy_model():
@@ -67,7 +73,28 @@ def extract_main_topic(query: str) -> str:
     # Return the extracted topic or a fallback value if no topic is found
     return main_topic if main_topic else "this topic"
 
-# 3) build_or_load_vectorstore (no changes)
+# 3) Function to moderate text using Mistral moderation API
+def moderate_text(query: str) -> str:
+    """
+    Classifies the query as harmful or not using Mistral Moderation API.
+    Returns "OutOfScope" if harmful, otherwise returns the original query.
+    """
+    response = mistral_client.classifiers.moderate(
+        model="mistral-moderation-latest",
+        inputs=[query]
+    )
+
+    categories = response.results[0].categories
+
+    # Check if any harmful category is flagged
+    if categories.get("violence_and_threats", False) or \
+       categories.get("hate_and_discrimination", False) or \
+       categories.get("dangerous_and_criminal_content", False) or \
+       categories.get("selfharm", False):
+        return "OutOfScope"
+    return query
+
+# 4) build_or_load_vectorstore (no changes)
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
         print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
@@ -96,7 +123,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     vectorstore.save_local(store_dir)
     return vectorstore
 
-# 4) Build RAG chain for Gemini (no changes)
+# 5) Build RAG chain for Gemini (no changes)
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
@@ -117,13 +144,13 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     )
     return rag_chain
 
-# 5) Initialize all the separate chains
+# 6) Initialize all the separate chains
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()  # Refusal chain will now use dynamic topic
 tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
-# 6) Build our vectorstores + RAG chains
+# 7) Build our vectorstores + RAG chains
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
@@ -136,7 +163,7 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
-# 7) Tools / Agents for web search (no changes)
+# 8) Tools / Agents for web search (no changes)
 search_tool = DuckDuckGoSearchTool()
 web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
 managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
@@ -148,19 +175,24 @@ def do_web_search(query: str) -> str:
     response = manager_agent.run(search_query)
     return response
 
-# 8) Orchestrator: run_with_chain
+# 9) Orchestrator: run_with_chain
 def run_with_chain(query: str) -> str:
     print("DEBUG: Starting run_with_chain...")
 
-    # 1) Classify the query
-    class_result = classification_chain.invoke({"query": query})
+    # 1) Moderate the query for harmful content
+    moderated_query = moderate_text(query)
+    if moderated_query == "OutOfScope":
+        return "Sorry, this query contains harmful or inappropriate content."
+
+    # 2) Classify the query
+    class_result = classification_chain.invoke({"query": moderated_query})
     classification = class_result.get("text", "").strip()
     print("DEBUG: Classification =>", classification)
 
     # If OutOfScope => refusal => tailor => return
     if classification == "OutOfScope":
         # Extract the main topic for the refusal message
-        topic = extract_main_topic(query)
+        topic = extract_main_topic(moderated_query)
         print("DEBUG: Extracted Topic =>", topic)
 
         # Pass the extracted topic to the refusal chain
@@ -170,14 +202,14 @@ def run_with_chain(query: str) -> str:
 
     # If Wellness => wellness RAG => if insufficient => web => unify => tailor
     if classification == "Wellness":
-        rag_result = wellness_rag_chain({"query": query})
+        rag_result = wellness_rag_chain({"query": moderated_query})
         csv_answer = rag_result["result"].strip()
         if not csv_answer:
-            web_answer = do_web_search(query)
+            web_answer = do_web_search(moderated_query)
         else:
             lower_ans = csv_answer.lower()
             if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
-                web_answer = do_web_search(query)
+                web_answer = do_web_search(moderated_query)
             else:
                 web_answer = ""
         final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
@@ -186,7 +218,7 @@ def run_with_chain(query: str) -> str:
 
     # If Brand => brand RAG => tailor => return
     if classification == "Brand":
-        rag_result = brand_rag_chain({"query": query})
+        rag_result = brand_rag_chain({"query": moderated_query})
         csv_answer = rag_result["result"].strip()
         final_merged = cleaner_chain.merge(kb=csv_answer, web="")
         final_answer = tailor_chain.run({"response": final_merged})
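
Note: the moderate_text gate added in this commit can be exercised on its own. Below is a minimal, self-contained sketch of the same check, assuming the mistralai v1 Python SDK (pip install mistralai), where client.classifiers.moderate(...) returns one result per input and results[0].categories is a mapping from category name to boolean. The is_harmful helper and the sample queries are illustrative, not part of the commit.

import os
from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

def is_harmful(text: str) -> bool:
    """Return True if any of the gated moderation categories is flagged."""
    response = client.classifiers.moderate(
        model="mistral-moderation-latest",
        inputs=[text],
    )
    # Same four categories that moderate_text() checks in the diff above
    categories = response.results[0].categories  # category name -> bool
    gated = (
        "violence_and_threats",
        "hate_and_discrimination",
        "dangerous_and_criminal_content",
        "selfharm",
    )
    return any(categories.get(name, False) for name in gated)

if __name__ == "__main__":
    # Hypothetical queries: the first should pass, the second should be flagged.
    for q in ("What are good breathing exercises?", "How do I hurt someone?"):
        print(q, "=>", "flagged" if is_harmful(q) else "ok")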
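
With the gate in place, run_with_chain short-circuits flagged queries before the classifier ever runs: moderate_text returns the sentinel string "OutOfScope" rather than a boolean, and the orchestrator checks for it first. A quick illustrative usage (hypothetical queries; requires all API keys and the CSV stores referenced in the diff):

# Harmless query: flows through moderation -> classification -> RAG -> tailor.
print(run_with_chain("Suggest a relaxing evening routine"))

# Flagged query: returns the apology string from the moderation gate.
print(run_with_chain("How can I make a weapon at home?"))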