Phoenix21 committed · verified
Commit 19fdb92 · 1 Parent(s): c947e4c

Update pipeline.py

Files changed (1):
  1. pipeline.py +8 -85
pipeline.py CHANGED
@@ -40,28 +40,11 @@ class QueryInput(BaseModel):
             raise ValueError("Query cannot be empty or just whitespace")
         return v.strip()
 
-class ClassificationResult(BaseModel):
-    category: str = Field(..., description="The classification category")
-    confidence: float = Field(..., ge=0.0, le=1.0, description="Classification confidence score")
-
 class ModerationResult(BaseModel):
     is_safe: bool = Field(..., description="Whether the content is safe")
     categories: Dict[str, bool] = Field(default_factory=dict, description="Detected content categories")
     original_text: str = Field(..., description="The original input text")
 
-class RAGResponse(BaseModel):
-    answer: str = Field(..., description="The generated answer")
-    sources: List[str] = Field(default_factory=list, description="Source documents used")
-    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score of the answer")
-
-def sanitize_message(message: Any) -> str:
-    """Sanitize message input to ensure it's a valid string."""
-    if hasattr(message, 'content'):
-        return str(message.content)
-    if isinstance(message, (list, dict)):
-        return str(message)
-    return str(message)
-
 # Load spaCy model for NER
 def install_spacy_model():
     try:
@@ -128,27 +111,18 @@ def moderate_text(query: str) -> ModerationResult:
     except Exception as e:
         raise RuntimeError(f"Moderation failed: {str(e)}")
 
-def classify_query(query: str) -> ClassificationResult:
+def classify_query(query: str) -> str:
     try:
         query_input = QueryInput(query=query)
 
         wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
         if any(keyword in query_input.query.lower() for keyword in wellness_keywords):
-            return ClassificationResult(category="Wellness", confidence=0.9)
+            return "Wellness"
 
         class_result = classification_chain.invoke({"query": query_input.query})
         classification = class_result.get("text", "").strip()
 
-        confidence_map = {
-            "Wellness": 0.8,
-            "Brand": 0.8,
-            "OutOfScope": 0.6
-        }
-
-        return ClassificationResult(
-            category=classification if classification != "" else "OutOfScope",
-            confidence=confidence_map.get(classification, 0.5)
-        )
+        return classification if classification != "" else "OutOfScope"
     except ValidationError as e:
         raise ValueError(f"Classification input validation failed: {str(e)}")
     except Exception as e:
@@ -166,14 +140,6 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
     df.columns = df.columns.str.strip()
 
-    if "Answer" in df.columns:
-        df.rename(columns={"Answer": "Answers"}, inplace=True)
-    if "Question" not in df.columns and "Question " in df.columns:
-        df.rename(columns={"Question ": "Question"}, inplace=True)
-
-    if "Question" not in df.columns or "Answers" not in df.columns:
-        raise ValueError("CSV must have 'Question' and 'Answers' columns")
-
     docs = [
         Document(page_content=str(row["Answers"]), metadata={"question": str(row["Question"])})
         for _, row in df.iterrows()
@@ -209,98 +175,55 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     except Exception as e:
         raise RuntimeError(f"Error building RAG chain: {str(e)}")
 
-def do_web_search(query: str) -> str:
-    try:
-        query_input = QueryInput(query=query)
-        search_tool = DuckDuckGoSearchTool()
-        web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
-        managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Performs web searches")
-        manager_agent = CodeAgent(tools=[], model=pydantic_agent, managed_agents=[managed_web_agent])
-
-        search_query = f"Give me relevant info: {query_input.query}"
-        return manager_agent.run(search_query)
-    except Exception as e:
-        return f"Web search failed: {str(e)}"
-
-def merge_responses(kb_answer: str, web_answer: str) -> str:
-    try:
-        if not kb_answer and not web_answer:
-            return "No relevant information found."
-
-        if not web_answer:
-            return kb_answer.strip()
-
-        if not kb_answer:
-            return web_answer.strip()
-
-        return f"Knowledge Base Answer: {kb_answer.strip()}\n\nWeb Search Result: {web_answer.strip()}"
-    except Exception as e:
-        return f"Error merging responses: {str(e)}"
-
 def run_pipeline(query: str) -> str:
     try:
-        # Sanitize and validate input
         query = sanitize_message(query)
 
-        # Moderate content
         moderation_result = moderate_text(query)
         if not moderation_result.is_safe:
             return "Sorry, this query contains harmful or inappropriate content."
 
-        # Classify the query
-        classification_result = classify_query(moderation_result.original_text)
+        classification = classify_query(moderation_result.original_text)
 
-        # Handle different classifications
-        if classification_result.category == "OutOfScope":
+        if classification == "OutOfScope":
             refusal_text = refusal_chain.run({"topic": "this topic"})
             return tailor_chain.run({"response": refusal_text}).strip()
 
-        if classification_result.category == "Wellness":
+        if classification == "Wellness":
             rag_result = wellness_rag_chain({"query": moderation_result.original_text})
             csv_answer = rag_result["result"].strip()
             web_answer = "" if csv_answer else do_web_search(moderation_result.original_text)
             final_merged = merge_responses(csv_answer, web_answer)
             return tailor_chain.run({"response": final_merged}).strip()
 
-        if classification_result.category == "Brand":
+        if classification == "Brand":
             rag_result = brand_rag_chain({"query": moderation_result.original_text})
             csv_answer = rag_result["result"].strip()
             final_merged = merge_responses(csv_answer, "")
             return tailor_chain.run({"response": final_merged}).strip()
 
-        # Default fallback
         refusal_text = refusal_chain.run({"topic": "this topic"})
         return tailor_chain.run({"response": refusal_text}).strip()
 
-    except Exception as e:
-        return f"An error occurred while processing your request: {str(e)}"
-
 # Initialize chains and vectorstores
 try:
-    # Initialize chain components
    classification_chain = get_classification_chain()
    refusal_chain = get_refusal_chain()
    tailor_chain = get_tailor_chain()
    cleaner_chain = get_cleaner_chain()
 
-    # Set up paths
    wellness_csv = "AIChatbot.csv"
    brand_csv = "BrandAI.csv"
    wellness_store_dir = "faiss_wellness_store"
    brand_store_dir = "faiss_brand_store"
 
-    # Build or load vectorstores
    wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
    brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
 
-    # Initialize LLM and RAG chains
    gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
    wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
    brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
-
+
    print("Pipeline initialized successfully!")
 except Exception as e:
    print(f"Error initializing pipeline: {str(e)}")
-
-def run_with_chain(query: str) -> str:
-    return run_pipeline(query)