Phoenix21 committed on
Commit 1eb0002 · verified · 1 Parent(s): 0b20500

Added pydantic error handling
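
In short: the pipeline's raw string inputs are now wrapped in pydantic models (QueryInput, ClassificationResult, ModerationResult, RAGResponse), so malformed queries fail fast with a ValidationError instead of propagating as sentinel strings. A minimal sketch of the validation pattern used in this commit, assuming pydantic v1 semantics (the @validator decorator below is deprecated in pydantic v2):

    from pydantic import BaseModel, Field, ValidationError, validator

    class QueryInput(BaseModel):
        query: str = Field(..., min_length=1)

        @validator('query')
        def check_query_is_string(cls, v):
            if v.strip() == "":
                raise ValueError("Query cannot be empty or just whitespace")
            return v.strip()  # downstream code sees the normalized value

    print(QueryInput(query="  box breathing  ").query)  # -> "box breathing"
    try:
        QueryInput(query="   ")  # whitespace-only input is rejected
    except ValidationError as e:
        print(e)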

Files changed (1)
  1. pipeline.py +214 -150
pipeline.py CHANGED
@@ -2,7 +2,7 @@ import os
 import getpass
 import spacy
 import pandas as pd
-from typing import Optional
+from typing import Optional, List, Dict, Any
 import subprocess
 from langchain.llms.base import LLM
 from langchain.docstore.document import Document
@@ -10,7 +10,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
-from pydantic import BaseModel, ValidationError, validator
+from pydantic import BaseModel, Field, ValidationError, validator
 from mistralai import Mistral
 from langchain.prompts import PromptTemplate
 
@@ -25,7 +25,33 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)
 
-# Load spaCy model for NER and download it if not already installed
+# Pydantic models for validation and type safety
+class QueryInput(BaseModel):
+    query: str = Field(..., min_length=1, description="The input query string")
+
+    @validator('query')
+    def check_query_is_string(cls, v):
+        if not isinstance(v, str):
+            raise ValueError("Query must be a valid string")
+        if v.strip() == "":
+            raise ValueError("Query cannot be empty or just whitespace")
+        return v.strip()
+
+class ClassificationResult(BaseModel):
+    category: str = Field(..., description="The classification category")
+    confidence: float = Field(..., ge=0.0, le=1.0, description="Classification confidence score")
+
+class ModerationResult(BaseModel):
+    is_safe: bool = Field(..., description="Whether the content is safe")
+    categories: Dict[str, bool] = Field(default_factory=dict, description="Detected content categories")
+    original_text: str = Field(..., description="The original input text")
+
+class RAGResponse(BaseModel):
+    answer: str = Field(..., description="The generated answer")
+    sources: List[str] = Field(default_factory=list, description="Source documents used")
+    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score of the answer")
+
+# Load spaCy model for NER
 def install_spacy_model():
     try:
         spacy.load("en_core_web_sm")
@@ -38,99 +64,121 @@ def install_spacy_model():
 install_spacy_model()
 nlp = spacy.load("en_core_web_sm")
 
-# Function to extract the main topic from the query using spaCy NER
 def extract_main_topic(query: str) -> str:
-    doc = nlp(query)
-    main_topic = None
-    for ent in doc.ents:
-        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
-            main_topic = ent.text
-            break
-    if not main_topic:
-        for token in doc:
-            if token.pos_ in ["NOUN", "PROPN"]:
-                main_topic = token.text
+    try:
+        query_input = QueryInput(query=query)
+        doc = nlp(query_input.query)
+        main_topic = None
+
+        # Try to find named entities first
+        for ent in doc.ents:
+            if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
+                main_topic = ent.text
                 break
-    return main_topic if main_topic else "this topic"
+
+        # If no named entities found, look for nouns
+        if not main_topic:
+            for token in doc:
+                if token.pos_ in ["NOUN", "PROPN"]:
+                    main_topic = token.text
+                    break
+
+        return main_topic if main_topic else "this topic"
+    except Exception as e:
+        print(f"Error extracting main topic: {e}")
+        return "this topic"
 
-# Pydantic model to handle string input validation
-class QueryInput(BaseModel):
-    query: str
-
-    # Validator to ensure the query is always a string
-    @validator('query')
-    def check_query_is_string(cls, v):
-        if not isinstance(v, str):
-            raise ValueError("Query must be a valid string.")
-        return v
-
-# Function to classify query based on wellness topics
-def classify_query(query: str) -> str:
-    wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
-    if any(keyword in query.lower() for keyword in wellness_keywords):
-        return "Wellness"
-    # Fallback to classification chain if not directly recognized
-    class_result = classification_chain.invoke({"query": query})
-    classification = class_result.get("text", "").strip()
-    return classification if classification != "OutOfScope" else "OutOfScope"
-
-# Function to moderate text using Mistral moderation API (sync version)
-def moderate_text(query: str) -> str:
+def moderate_text(query: str) -> ModerationResult:
     try:
-        # Use Pydantic to validate text input
-        query_input = QueryInput(query=query)  # This will validate that the query is a string
+        query_input = QueryInput(query=query)
+
+        response = client.classifiers.moderate_chat(
+            model="mistral-moderation-latest",
+            inputs=[{"role": "user", "content": query_input.query}]
+        )
+
+        is_safe = True
+        categories = {}
+
+        if hasattr(response, 'results') and response.results:
+            categories = {
+                "violence": response.results[0].categories.get("violence_and_threats", False),
+                "hate": response.results[0].categories.get("hate_and_discrimination", False),
+                "dangerous": response.results[0].categories.get("dangerous_and_criminal_content", False),
+                "selfharm": response.results[0].categories.get("selfharm", False)
+            }
+            is_safe = not any(categories.values())
+
+        return ModerationResult(
+            is_safe=is_safe,
+            categories=categories,
+            original_text=query_input.query
+        )
     except ValidationError as e:
-        print(f"Error validating text: {e}")
-        return "Invalid text format."
-
-    # Call the Mistral moderation API
-    response = client.classifiers.moderate_chat(
-        model="mistral-moderation-latest",
-        inputs=[{"role": "user", "content": query}]
-    )
-
-    # Check if harmful categories are present in the response
-    if hasattr(response, 'results') and response.results:
-        categories = response.results[0].categories
-        if categories.get("violence_and_threats", False) or \
-           categories.get("hate_and_discrimination", False) or \
-           categories.get("dangerous_and_criminal_content", False) or \
-           categories.get("selfharm", False):
-            return "OutOfScope"
-
-    return query
+        raise ValueError(f"Input validation failed: {str(e)}")
+    except Exception as e:
+        raise RuntimeError(f"Moderation failed: {str(e)}")
 
+def classify_query(query: str) -> ClassificationResult:
+    try:
+        query_input = QueryInput(query=query)
+
+        wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
+        if any(keyword in query_input.query.lower() for keyword in wellness_keywords):
+            return ClassificationResult(category="Wellness", confidence=0.9)
+
+        class_result = classification_chain.invoke({"query": query_input.query})
+        classification = class_result.get("text", "").strip()
+
+        confidence_map = {
+            "Wellness": 0.8,
+            "Brand": 0.8,
+            "OutOfScope": 0.6
+        }
+
+        return ClassificationResult(
+            category=classification if classification != "" else "OutOfScope",
+            confidence=confidence_map.get(classification, 0.5)
+        )
    except ValidationError as e:
+        raise ValueError(f"Classification input validation failed: {str(e)}")
+    except Exception as e:
+        raise RuntimeError(f"Classification failed: {str(e)}")
 
-# Function to build or load the vector store from CSV data
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
-    if os.path.exists(store_dir):
-        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
-        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
-        vectorstore = FAISS.load_local(store_dir, embeddings)
-        return vectorstore
-    else:
-        print(f"DEBUG: Building new store from CSV: {csv_path}")
+    try:
+        if os.path.exists(store_dir):
+            print(f"Loading existing FAISS store from '{store_dir}'")
+            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+            return FAISS.load_local(store_dir, embeddings)
+
+        print(f"Building new FAISS store from CSV: {csv_path}")
         df = pd.read_csv(csv_path)
         df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
         df.columns = df.columns.str.strip()
+
+        # Handle column name variations
         if "Answer" in df.columns:
             df.rename(columns={"Answer": "Answers"}, inplace=True)
         if "Question" not in df.columns and "Question " in df.columns:
             df.rename(columns={"Question ": "Question"}, inplace=True)
+
         if "Question" not in df.columns or "Answers" not in df.columns:
-            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
-        docs = []
-        for _, row in df.iterrows():
-            q = str(row["Question"])
-            ans = str(row["Answers"])
-            doc = Document(page_content=ans, metadata={"question": q})
-            docs.append(doc)
+            raise ValueError("CSV must have 'Question' and 'Answers' columns")
+
+        docs = [
+            Document(page_content=str(row["Answers"]), metadata={"question": str(row["Question"])})
+            for _, row in df.iterrows()
+        ]
+
         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
         vectorstore = FAISS.from_documents(docs, embedding=embeddings)
         vectorstore.save_local(store_dir)
         return vectorstore
+
+    except Exception as e:
+        raise RuntimeError(f"Error building/loading vector store: {str(e)}")
 
-# Function to build RAG chain
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
@@ -141,87 +189,103 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
         def _llm_type(self) -> str:
             return "custom_gemini"
 
-    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
-    gemini_as_llm = GeminiLangChainLLM()
-    rag_chain = RetrievalQA.from_chain_type(
-        llm=gemini_as_llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True
-    )
-    return rag_chain
-
-# Function to perform web search using DuckDuckGo
+    try:
+        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+        gemini_as_llm = GeminiLangChainLLM()
+        return RetrievalQA.from_chain_type(
+            llm=gemini_as_llm,
+            chain_type="stuff",
+            retriever=retriever,
+            return_source_documents=True
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error building RAG chain: {str(e)}")
+
 def do_web_search(query: str) -> str:
-    search_tool = DuckDuckGoSearchTool()
-    web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
-    managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
-    manager_agent = CodeAgent(tools=[], model=pydantic_agent, managed_agents=[managed_web_agent])
-
-    search_query = f"Give me relevant info: {query}"
-    response = manager_agent.run(search_query)
-    return response
+    try:
+        query_input = QueryInput(query=query)
+        search_tool = DuckDuckGoSearchTool()
+        web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
+        managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Performs web searches")
+        manager_agent = CodeAgent(tools=[], model=pydantic_agent, managed_agents=[managed_web_agent])
+
+        search_query = f"Give me relevant info: {query_input.query}"
+        return manager_agent.run(search_query)
+    except Exception as e:
+        return f"Web search failed: {str(e)}"
 
-# Function to combine web and knowledge base responses
 def merge_responses(kb_answer: str, web_answer: str) -> str:
-    # Merge both answers with a cohesive response
-    final_answer = f"Knowledge Base Answer: {kb_answer}\n\nWeb Search Result: {web_answer}"
-    return final_answer.strip()
+    try:
+        if not kb_answer and not web_answer:
+            return "No relevant information found."
+
+        if not web_answer:
+            return kb_answer.strip()
+
+        if not kb_answer:
+            return web_answer.strip()
+
+        return f"Knowledge Base Answer: {kb_answer.strip()}\n\nWeb Search Result: {web_answer.strip()}"
+    except Exception as e:
+        return f"Error merging responses: {str(e)}"
 
-# Orchestrate the entire workflow
 def run_pipeline(query: str) -> str:
-    # Moderate the query for harmful content
-    moderated_query = moderate_text(query)
-    if moderated_query == "OutOfScope":
-        return "Sorry, this query contains harmful or inappropriate content."
+    try:
+        # Validate and moderate input
+        moderation_result = moderate_text(query)
+        if not moderation_result.is_safe:
+            return "Sorry, this query contains harmful or inappropriate content."
+
+        # Classify the query
+        classification_result = classify_query(moderation_result.original_text)
 
-    # Classify the query manually
-    classification = classify_query(moderated_query)
+        if classification_result.category == "OutOfScope":
+            refusal_text = refusal_chain.run({"topic": "this topic"})
+            return tailor_chain.run({"response": refusal_text}).strip()
 
-    if classification == "OutOfScope":
+        # Handle different classifications
+        if classification_result.category == "Wellness":
+            rag_result = wellness_rag_chain({"query": moderation_result.original_text})
+            csv_answer = rag_result["result"].strip()
+            web_answer = "" if csv_answer else do_web_search(moderation_result.original_text)
+            final_merged = merge_responses(csv_answer, web_answer)
+            return tailor_chain.run({"response": final_merged}).strip()
+
+        if classification_result.category == "Brand":
+            rag_result = brand_rag_chain({"query": moderation_result.original_text})
+            csv_answer = rag_result["result"].strip()
+            final_merged = merge_responses(csv_answer, "")
+            return tailor_chain.run({"response": final_merged}).strip()
+
+        # Default fallback
         refusal_text = refusal_chain.run({"topic": "this topic"})
-    final_refusal = tailor_chain.run({"response": refusal_text})
-    return final_refusal.strip()
-
-    if classification == "Wellness":
-        rag_result = wellness_rag_chain({"query": moderated_query})
-        csv_answer = rag_result["result"].strip()
-        web_answer = ""  # Empty if we found an answer from the knowledge base
-        if not csv_answer:
-            web_answer = do_web_search(moderated_query)
-        final_merged = merge_responses(csv_answer, web_answer)
-        final_answer = tailor_chain.run({"response": final_merged})
-        return final_answer.strip()
-
-    if classification == "Brand":
-        rag_result = brand_rag_chain({"query": moderated_query})
-        csv_answer = rag_result["result"].strip()
-        final_merged = merge_responses(csv_answer, "")
-        final_answer = tailor_chain.run({"response": final_merged})
-        return final_answer.strip()
-
-    refusal_text = refusal_chain.run({"topic": "this topic"})
-    final_refusal = tailor_chain.run({"response": refusal_text})
-    return final_refusal.strip()
-
-# Initialize chains
-classification_chain = get_classification_chain()
-refusal_chain = get_refusal_chain()
-tailor_chain = get_tailor_chain()
-cleaner_chain = get_cleaner_chain()
-
-wellness_csv = "AIChatbot.csv"
-brand_csv = "BrandAI.csv"
-wellness_store_dir = "faiss_wellness_store"
-brand_store_dir = "faiss_brand_store"
-
-wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
-brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
-
-gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
-wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
-brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
-
-# Function to wrap up and run the chain
+        return tailor_chain.run({"response": refusal_text}).strip()
+
+    except Exception as e:
+        return f"An error occurred while processing your request: {str(e)}"
+
+# Initialize chains and vectorstores
+try:
+    classification_chain = get_classification_chain()
+    refusal_chain = get_refusal_chain()
+    tailor_chain = get_tailor_chain()
+    cleaner_chain = get_cleaner_chain()
+
+    wellness_csv = "AIChatbot.csv"
+    brand_csv = "BrandAI.csv"
+    wellness_store_dir = "faiss_wellness_store"
+    brand_store_dir = "faiss_brand_store"
+
+    wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
+    brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
+
+    gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
+    wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
+    brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
+
+    print("Pipeline initialized successfully!")
+except Exception as e:
+    print(f"Error initializing pipeline: {str(e)}")
+
 def run_with_chain(query: str) -> str:
-    return run_pipeline(query)
+    return run_pipeline(query)
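
With these changes, callers receive structured objects rather than sentinel strings like "OutOfScope". A hypothetical usage sketch (the import path and queries are illustrative; it assumes pipeline.py is importable and MISTRAL_API_KEY / GEMINI_API_KEY are set in the environment):

    from pipeline import moderate_text, classify_query, run_with_chain

    moderation = moderate_text("What is box breathing?")   # -> ModerationResult
    if moderation.is_safe:
        label = classify_query(moderation.original_text)   # -> ClassificationResult
        print(label.category, label.confidence)
        print(run_with_chain(moderation.original_text))
    else:
        print("Flagged:", [k for k, v in moderation.categories.items() if v])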