Phoenix21 committed
Commit df1f812 · verified · 1 Parent(s): 3bc6f69

to handle pydantic error

Files changed (1)
  1. pipeline.py +115 -53
pipeline.py CHANGED
@@ -9,12 +9,13 @@ from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
-from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
+from smolagents import DuckDuckGoSearchTool, ManagedAgent
 from pydantic import BaseModel, Field, ValidationError, validator
 from mistralai import Mistral
-from langchain.prompts import PromptTemplate
 
-# Import chains and tools
+# Import Google Gemini model
+from langchain_google_genai import ChatGoogleGenerativeAI
+
 from classification_chain import get_classification_chain
 from cleaner_chain import get_cleaner_chain
 from refusal_chain import get_refusal_chain
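Note: the added import depends on the langchain-google-genai package and a GOOGLE_API_KEY environment variable, replacing the GEMINI_API_KEY used by the removed LiteLLMModel. A minimal sketch of the new wrapper, assuming the package is installed:

    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0)
    reply = llm.invoke([("human", "Say hello")])  # returns an AIMessage
    print(reply.content)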
@@ -25,10 +26,25 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)
 
-# Initialize LiteLLM model for web search
-pydantic_agent = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
+# Set up ChatGoogleGenerativeAI for Gemini.
+# Ensure GOOGLE_API_KEY is set in your environment variables.
+gemini_llm = ChatGoogleGenerativeAI(
+    model="gemini-1.5-pro",
+    temperature=0,
+    max_retries=2,
+    # You can add additional parameters or safety_settings here if needed
+)
+
+# Initialize a managed agent for web search (if needed)
+pydantic_agent = ManagedAgent(
+    llm=ChatGoogleGenerativeAI(
+        model="gemini-1.5-pro",
+        temperature=0,
+        max_retries=2,
+    ),
+    tools=[DuckDuckGoSearchTool()]
+)
 
-# Pydantic models for validation and type safety
 class QueryInput(BaseModel):
     query: str = Field(..., min_length=1, description="The input query string")
 
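The min_length=1 constraint on QueryInput.query is the validation this commit is about: an empty query raises pydantic's ValidationError, which run_pipeline (below) now converts into a ValueError. A quick illustration of the failure mode; the reported error type varies across pydantic versions:

    from pydantic import BaseModel, Field, ValidationError

    class QueryInput(BaseModel):
        query: str = Field(..., min_length=1)

    try:
        QueryInput(query="")  # violates min_length=1
    except ValidationError as e:
        print(e)  # e.g. "string_too_short" under pydantic v2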
@@ -45,7 +61,6 @@ class ModerationResult(BaseModel):
     categories: Dict[str, bool] = Field(default_factory=dict, description="Detected content categories")
     original_text: str = Field(..., description="The original input text")
 
-# Load spaCy model for NER
 def install_spacy_model():
     try:
         spacy.load("en_core_web_sm")
@@ -58,6 +73,22 @@ def install_spacy_model():
 install_spacy_model()
 nlp = spacy.load("en_core_web_sm")
 
+def sanitize_message(message: Any) -> str:
+    """Sanitize message input to ensure it's a valid string."""
+    try:
+        if hasattr(message, 'content'):
+            return str(message.content).strip()
+        if isinstance(message, dict) and 'content' in message:
+            return str(message['content']).strip()
+        if isinstance(message, list) and len(message) > 0:
+            if isinstance(message[0], dict) and 'content' in message[0]:
+                return str(message[0]['content']).strip()
+            if hasattr(message[0], 'content'):
+                return str(message[0].content).strip()
+        return str(message).strip()
+    except Exception as e:
+        raise RuntimeError(f"Error in sanitize function: {str(e)}")
+
 def extract_main_topic(query: str) -> str:
     try:
         query_input = QueryInput(query=query)
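sanitize_message now unwraps the common chat-message shapes (objects with a .content attribute, dicts, and lists of either), where the old version str()-ed lists and dicts wholesale, so a chat-style payload became a stringified dict rather than the user's text. Illustrative calls, assuming the definitions above:

    sanitize_message("  plain text ")                     # -> "plain text"
    sanitize_message({"content": " hi "})                 # -> "hi"
    sanitize_message([{"role": "user", "content": "q"}])  # -> "q"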
@@ -160,55 +191,76 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     except Exception as e:
         raise RuntimeError(f"Error building/loading vector store: {str(e)}")
 
-def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
-    class GeminiLangChainLLM(LLM):
-        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
-            messages = [{"role": "user", "content": prompt}]
-            return llm_model(messages, stop_sequences=stop)
+class GeminiLangChainLLM(LLM):
+    def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
+        """Call the Gemini model using ChatGoogleGenerativeAI and ensure string output."""
+        try:
+            # Construct message list for the Gemini model
+            messages = [("human", prompt)]
+            ai_msg = gemini_llm.invoke(messages)
+            return ai_msg.content.strip() if ai_msg and ai_msg.content else str(prompt)
+        except Exception as e:
+            print(f"Error in GeminiLangChainLLM._call: {e}")
+            return str(prompt)  # Fallback to returning the prompt
 
-        @property
-        def _llm_type(self) -> str:
-            return "custom_gemini"
-
+    @property
+    def _llm_type(self) -> str:
+        return "custom_gemini"
+
+def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
+    """Build RAG chain with enhanced error handling."""
     try:
         retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
-        gemini_as_llm = GeminiLangChainLLM()
-        return RetrievalQA.from_chain_type(
-            llm=gemini_as_llm,
+        gemini_llm_instance = GeminiLangChainLLM()
+        chain = RetrievalQA.from_chain_type(
+            llm=gemini_llm_instance,
             chain_type="stuff",
             retriever=retriever,
             return_source_documents=True
         )
+        return chain
     except Exception as e:
         raise RuntimeError(f"Error building RAG chain: {str(e)}")
-
-
-def sanitize_message(message: Any) -> str:
-    """Sanitize message input to ensure it's a valid string."""
+
+def do_web_search(query: str) -> str:
     try:
-        if hasattr(message, 'content'):
-            return str(message.content)
-        if isinstance(message, (list, dict)):
-            return str(message)
-        return str(message)
+        search_tool = DuckDuckGoSearchTool()
+        search_agent = ManagedAgent(llm=gemini_llm, tools=[search_tool])
+        search_result = search_agent.run(f"Search for information about: {query}")
+        return str(search_result).strip()
     except Exception as e:
-        raise RuntimeError(f"Error in sanitize function: {str(e)}")
+        print(f"Web search failed: {e}")
+        return ""
+
+def merge_responses(csv_answer: str, web_answer: str) -> str:
+    try:
+        if not csv_answer and not web_answer:
+            return "I apologize, but I couldn't find any relevant information."
+
+        if not web_answer:
+            return csv_answer
+
+        if not csv_answer:
+            return web_answer
+
+        return f"{csv_answer}\n\nAdditional information from web search:\n{web_answer}"
+    except Exception as e:
+        print(f"Error merging responses: {e}")
+        return csv_answer or web_answer or "I apologize, but I couldn't process the information properly."
 
-
 def run_pipeline(query: str) -> str:
     try:
-        query = sanitize_message(query)
-        topic=extract_main_topic(query)
-        moderation_result = moderate_text(query)
-        try:
-            if not moderation_result.is_safe:
-                return "Sorry, this query contains harmful or inappropriate content."
-        except Exception as e:
-            raise RuntimeError(f"Error in run_runpipeline check moderation: {str(e)}")
-        try:
-            classification = classify_query(moderation_result.original_text)
-        except Exception as e:
-            raise RuntimeError(f"Error in run_runpipeline check classify_query: {str(e)}")
+        print(query)
+        sanitized_query = sanitize_message(query)
+        query_input = QueryInput(query=sanitized_query)
+
+        topic = extract_main_topic(query_input.query)
+        moderation_result = moderate_text(query_input.query)
+
+        if not moderation_result.is_safe:
+            return "Sorry, this query contains harmful or inappropriate content."
+
+        classification = classify_query(moderation_result.original_text)
 
         if classification == "OutOfScope":
             refusal_text = refusal_chain.run({"topic": topic})
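GeminiLangChainLLM is now a module-level LangChain LLM subclass (only _call and _llm_type need overriding) that delegates to the shared gemini_llm, which is why build_rag_chain no longer takes an llm argument. One caveat: do_web_search calls ManagedAgent(llm=..., tools=[...]); verify that keyword signature against your installed smolagents version, since its ManagedAgent has historically wrapped an existing agent rather than accepting llm/tools directly. The new merge_responses is plain string logic and easy to sanity-check:

    merge_responses("csv hit", "")  # -> "csv hit"
    merge_responses("", "web hit")  # -> "web hit"
    merge_responses("a", "b")       # -> "a\n\nAdditional information from web search:\nb"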
@@ -216,22 +268,37 @@ def run_pipeline(query: str) -> str:
 
         if classification == "Wellness":
             rag_result = wellness_rag_chain({"query": moderation_result.original_text})
-            csv_answer = rag_result["result"].strip()
+            if isinstance(rag_result, dict) and "result" in rag_result:
+                csv_answer = str(rag_result["result"]).strip()
+            else:
+                csv_answer = str(rag_result).strip()
             web_answer = "" if csv_answer else do_web_search(moderation_result.original_text)
             final_merged = merge_responses(csv_answer, web_answer)
             return tailor_chain.run({"response": final_merged}).strip()
 
         if classification == "Brand":
             rag_result = brand_rag_chain({"query": moderation_result.original_text})
-            csv_answer = rag_result["result"].strip()
+            if isinstance(rag_result, dict) and "result" in rag_result:
+                csv_answer = str(rag_result["result"]).strip()
+            else:
+                csv_answer = str(rag_result).strip()
             final_merged = merge_responses(csv_answer, "")
             return tailor_chain.run({"response": final_merged}).strip()
 
         refusal_text = refusal_chain.run({"topic": topic})
         return tailor_chain.run({"response": refusal_text}).strip()
 
-    except Exception as e:
-        raise RuntimeError(f"Error in run_runpipeline: {str(e)}")
+    except ValidationError as e:
+        raise ValueError(f"Input validation failed: {str(e)}")
+    except Exception as e:
+        raise RuntimeError(f"Error in run_pipeline: {str(e)}")
+
+def run_with_chain(query: str) -> str:
+    try:
+        return run_pipeline(query)
+    except Exception as e:
+        print(f"Error in run_with_chain: {str(e)}")
+        return "I apologize, but I encountered an error processing your request. Please try again."
 
 # Initialize chains and vectorstores
 classification_chain = get_classification_chain()
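run_pipeline now separates the two failure modes: pydantic's ValidationError is re-raised as ValueError, anything else as RuntimeError, and the new run_with_chain wrapper turns both into a user-facing apology. Caller-side handling, with a hypothetical query string:

    try:
        print(run_pipeline("What teas help with sleep?"))
    except ValueError as e:    # invalid input, e.g. an empty query
        print(f"Bad input: {e}")
    except RuntimeError as e:  # any other pipeline failure
        print(f"Pipeline error: {e}")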
@@ -247,12 +314,7 @@ brand_store_dir = "faiss_brand_store"
 wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
 brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
 
-gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
-wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
-brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
+wellness_rag_chain = build_rag_chain(wellness_vectorstore)
+brand_rag_chain = build_rag_chain(brand_vectorstore)
 
 print("Pipeline initialized successfully!")
-
-
-def run_with_chain(query: str) -> str:
-    return run_pipeline(query)
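With the chains now built after all definitions, build_rag_chain is called without an llm argument, and run_with_chain (moved earlier in the file) remains the entry point for callers that must never see an exception:

    answer = run_with_chain("hello")  # returns an apology string on any failure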