Phoenix21 committed on
Commit
c947e4c
·
verified ·
1 Parent(s): a684f83

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +18 -18
pipeline.py CHANGED
@@ -25,6 +25,9 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
25
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
26
  client = Mistral(api_key=mistral_api_key)
27
 
 
 
 
28
  # Pydantic models for validation and type safety
29
  class QueryInput(BaseModel):
30
  query: str = Field(..., min_length=1, description="The input query string")
@@ -51,6 +54,14 @@ class RAGResponse(BaseModel):
51
  sources: List[str] = Field(default_factory=list, description="Source documents used")
52
  confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score of the answer")
53
 
 
 
 
 
 
 
 
 
54
  # Load spaCy model for NER
55
  def install_spacy_model():
56
  try:
@@ -70,13 +81,11 @@ def extract_main_topic(query: str) -> str:
70
  doc = nlp(query_input.query)
71
  main_topic = None
72
 
73
- # Try to find named entities first
74
  for ent in doc.ents:
75
  if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
76
  main_topic = ent.text
77
  break
78
 
79
- # If no named entities found, look for nouns
80
  if not main_topic:
81
  for token in doc:
82
  if token.pos_ in ["NOUN", "PROPN"]:
@@ -157,7 +166,6 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
157
  df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
158
  df.columns = df.columns.str.strip()
159
 
160
- # Handle column name variations
161
  if "Answer" in df.columns:
162
  df.rename(columns={"Answer": "Answers"}, inplace=True)
163
  if "Question" not in df.columns and "Question " in df.columns:
@@ -228,25 +236,13 @@ def merge_responses(kb_answer: str, web_answer: str) -> str:
228
  return f"Knowledge Base Answer: {kb_answer.strip()}\n\nWeb Search Result: {web_answer.strip()}"
229
  except Exception as e:
230
  return f"Error merging responses: {str(e)}"
231
- def sanitize_message(message: Any) -> str:
232
- """Sanitize message input to ensure it's a valid string."""
233
- if hasattr(message, 'content'):
234
- return str(message.content)
235
- if isinstance(message, (list, dict)):
236
- return str(message)
237
- return str(message)
238
 
239
- # Modify your run_pipeline function to include the sanitization
240
  def run_pipeline(query: str) -> str:
241
  try:
242
- # Sanitize input
243
  query = sanitize_message(query)
244
 
245
- # Rest of your pipeline code...
246
- moderation_result = moderate_text(query)
247
- if not moderation_result.is_safe:
248
- return "Sorry, this query contains harmful or inappropriate content."
249
- # Validate and moderate input
250
  moderation_result = moderate_text(query)
251
  if not moderation_result.is_safe:
252
  return "Sorry, this query contains harmful or inappropriate content."
@@ -254,11 +250,11 @@ def run_pipeline(query: str) -> str:
254
  # Classify the query
255
  classification_result = classify_query(moderation_result.original_text)
256
 
 
257
  if classification_result.category == "OutOfScope":
258
  refusal_text = refusal_chain.run({"topic": "this topic"})
259
  return tailor_chain.run({"response": refusal_text}).strip()
260
 
261
- # Handle different classifications
262
  if classification_result.category == "Wellness":
263
  rag_result = wellness_rag_chain({"query": moderation_result.original_text})
264
  csv_answer = rag_result["result"].strip()
@@ -281,19 +277,23 @@ def run_pipeline(query: str) -> str:
281
 
282
  # Initialize chains and vectorstores
283
  try:
 
284
  classification_chain = get_classification_chain()
285
  refusal_chain = get_refusal_chain()
286
  tailor_chain = get_tailor_chain()
287
  cleaner_chain = get_cleaner_chain()
288
 
 
289
  wellness_csv = "AIChatbot.csv"
290
  brand_csv = "BrandAI.csv"
291
  wellness_store_dir = "faiss_wellness_store"
292
  brand_store_dir = "faiss_brand_store"
293
 
 
294
  wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
295
  brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
296
 
 
297
  gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
298
  wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
299
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
25
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
26
  client = Mistral(api_key=mistral_api_key)
27
 
28
+ # Initialize LiteLLM model for web search
29
+ pydantic_agent = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
30
+
31
  # Pydantic models for validation and type safety
32
  class QueryInput(BaseModel):
33
  query: str = Field(..., min_length=1, description="The input query string")
 
54
  sources: List[str] = Field(default_factory=list, description="Source documents used")
55
  confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score of the answer")
56
 
57
+ def sanitize_message(message: Any) -> str:
58
+ """Sanitize message input to ensure it's a valid string."""
59
+ if hasattr(message, 'content'):
60
+ return str(message.content)
61
+ if isinstance(message, (list, dict)):
62
+ return str(message)
63
+ return str(message)
64
+
65
  # Load spaCy model for NER
66
  def install_spacy_model():
67
  try:
 
81
  doc = nlp(query_input.query)
82
  main_topic = None
83
 
 
84
  for ent in doc.ents:
85
  if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
86
  main_topic = ent.text
87
  break
88
 
 
89
  if not main_topic:
90
  for token in doc:
91
  if token.pos_ in ["NOUN", "PROPN"]:
 
166
  df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
167
  df.columns = df.columns.str.strip()
168
 
 
169
  if "Answer" in df.columns:
170
  df.rename(columns={"Answer": "Answers"}, inplace=True)
171
  if "Question" not in df.columns and "Question " in df.columns:
 
236
  return f"Knowledge Base Answer: {kb_answer.strip()}\n\nWeb Search Result: {web_answer.strip()}"
237
  except Exception as e:
238
  return f"Error merging responses: {str(e)}"
 
 
 
 
 
 
 
239
 
 
240
  def run_pipeline(query: str) -> str:
241
  try:
242
+ # Sanitize and validate input
243
  query = sanitize_message(query)
244
 
245
+ # Moderate content
 
 
 
 
246
  moderation_result = moderate_text(query)
247
  if not moderation_result.is_safe:
248
  return "Sorry, this query contains harmful or inappropriate content."
 
250
  # Classify the query
251
  classification_result = classify_query(moderation_result.original_text)
252
 
253
+ # Handle different classifications
254
  if classification_result.category == "OutOfScope":
255
  refusal_text = refusal_chain.run({"topic": "this topic"})
256
  return tailor_chain.run({"response": refusal_text}).strip()
257
 
 
258
  if classification_result.category == "Wellness":
259
  rag_result = wellness_rag_chain({"query": moderation_result.original_text})
260
  csv_answer = rag_result["result"].strip()
 
277
 
278
  # Initialize chains and vectorstores
279
  try:
280
+ # Initialize chain components
281
  classification_chain = get_classification_chain()
282
  refusal_chain = get_refusal_chain()
283
  tailor_chain = get_tailor_chain()
284
  cleaner_chain = get_cleaner_chain()
285
 
286
+ # Set up paths
287
  wellness_csv = "AIChatbot.csv"
288
  brand_csv = "BrandAI.csv"
289
  wellness_store_dir = "faiss_wellness_store"
290
  brand_store_dir = "faiss_brand_store"
291
 
292
+ # Build or load vectorstores
293
  wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
294
  brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
295
 
296
+ # Initialize LLM and RAG chains
297
  gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
298
  wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
299
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)