Phoenix21 committed (verified)
Commit 53b33ac · Parent(s): 293661c

Update pipeline.py

Files changed (1): pipeline.py (+108 -8)
pipeline.py CHANGED
@@ -4,6 +4,7 @@ import spacy
 import pandas as pd
 from typing import Optional
 import subprocess
+import asyncio # Needed for managing async tasks
 from langchain.llms.base import LLM
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -12,7 +13,14 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 from pydantic_ai import Agent # Import Pydantic AI's Agent
 from mistralai import Mistral
-import asyncio # Needed for managing async tasks
+from langchain.prompts import PromptTemplate
+
+# Import chains and tools
+from classification_chain import get_classification_chain
+from cleaner_chain import get_cleaner_chain
+from refusal_chain import get_refusal_chain
+from tailor_chain import get_tailor_chain
+from prompts import classification_prompt, refusal_prompt, tailor_prompt
 
 # Initialize Mistral API client
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
@@ -54,23 +62,30 @@ def classify_query(query: str) -> str:
     wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
     if any(keyword in query.lower() for keyword in wellness_keywords):
         return "Wellness"
+    # Fallback to classification chain if not directly recognized
     class_result = classification_chain.invoke({"query": query})
     classification = class_result.get("text", "").strip()
     return classification if classification != "OutOfScope" else "OutOfScope"
 
-# Function to moderate text using Mistral moderation API (async version)
-async def moderate_text(query: str) -> str:
+# Function to moderate text using Mistral moderation API (sync version)
+def moderate_text(query: str) -> str:
     try:
-        await pydantic_agent.run(query) # Use async run for Pydantic validation
+        # Use Pydantic AI for text validation synchronously
+        pydantic_agent.run(query) # This is a synchronous call
     except Exception as e:
         print(f"Error validating text: {e}")
         return "Invalid text format."
 
-    response = await client.classifiers.moderate_chat(
+    # Mistral moderation, no need for await as it's synchronous
+    response = client.classifiers.moderate_chat(
         model="mistral-moderation-latest",
         inputs=[{"role": "user", "content": query}]
     )
+
+    # Extract moderation categories
     categories = response['results'][0]['categories']
+
+    # Check for harmful categories and return "OutOfScope" if any are found
     if categories.get("violence_and_threats", False) or \
        categories.get("hate_and_discrimination", False) or \
        categories.get("dangerous_and_criminal_content", False) or \
@@ -79,7 +94,74 @@ async def moderate_text(query: str) -> str:
 
     return query
 
-# Use the event loop to run the async functions properly
+# Function to build or load the vector store from CSV data
+def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
+    if os.path.exists(store_dir):
+        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
+        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+        vectorstore = FAISS.load_local(store_dir, embeddings)
+        return vectorstore
+    else:
+        print(f"DEBUG: Building new store from CSV: {csv_path}")
+        df = pd.read_csv(csv_path)
+        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
+        df.columns = df.columns.str.strip()
+        if "Answer" in df.columns:
+            df.rename(columns={"Answer": "Answers"}, inplace=True)
+        if "Question" not in df.columns and "Question " in df.columns:
+            df.rename(columns={"Question ": "Question"}, inplace=True)
+        if "Question" not in df.columns or "Answers" not in df.columns:
+            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
+        docs = []
+        for _, row in df.iterrows():
+            q = str(row["Question"])
+            ans = str(row["Answers"])
+            doc = Document(page_content=ans, metadata={"question": q})
+            docs.append(doc)
+        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
+        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
+        vectorstore.save_local(store_dir)
+        return vectorstore
+
+# Function to build RAG chain
+def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
+    class GeminiLangChainLLM(LLM):
+        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
+            messages = [{"role": "user", "content": prompt}]
+            return llm_model(messages, stop_sequences=stop)
+
+        @property
+        def _llm_type(self) -> str:
+            return "custom_gemini"
+
+    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+    gemini_as_llm = GeminiLangChainLLM()
+    rag_chain = RetrievalQA.from_chain_type(
+        llm=gemini_as_llm,
+        chain_type="stuff",
+        retriever=retriever,
+        return_source_documents=True
+    )
+    return rag_chain
+
+# Function to perform web search using DuckDuckGo
+async def do_web_search(query: str) -> str:
+    search_tool = DuckDuckGoSearchTool()
+    web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
+    managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
+    manager_agent = CodeAgent(tools=[], model=pydantic_agent, managed_agents=[managed_web_agent])
+
+    search_query = f"Give me relevant info: {query}"
+    response = manager_agent.run(search_query)
+    return response
+
+# Function to combine web and knowledge base responses
+async def merge_responses(kb_answer: str, web_answer: str) -> str:
+    # Merge both answers with a cohesive response
+    final_answer = f"Knowledge Base Answer: {kb_answer}\n\nWeb Search Result: {web_answer}"
+    return final_answer.strip()
+
+# Orchestrate the entire workflow
 async def run_async_pipeline(query: str) -> str:
     # Moderate the query for harmful content (async)
     moderated_query = await moderate_text(query)
@@ -100,14 +182,14 @@ async def run_async_pipeline(query: str) -> str:
         web_answer = "" # Empty if we found an answer from the knowledge base
         if not csv_answer:
             web_answer = await do_web_search(moderated_query)
-        final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
+        final_merged = await merge_responses(csv_answer, web_answer)
         final_answer = tailor_chain.run({"response": final_merged})
         return final_answer.strip()
 
     if classification == "Brand":
         rag_result = brand_rag_chain({"query": moderated_query})
         csv_answer = rag_result["result"].strip()
-        final_merged = cleaner_chain.merge(kb=csv_answer, web="")
+        final_merged = await merge_responses(csv_answer, "")
         final_answer = tailor_chain.run({"response": final_merged})
         return final_answer.strip()
 
@@ -118,3 +200,21 @@ async def run_async_pipeline(query: str) -> str:
 # Run the pipeline with the event loop
 def run_with_chain(query: str) -> str:
     return asyncio.run(run_async_pipeline(query))
+
+# Initialize chains here
+classification_chain = get_classification_chain()
+refusal_chain = get_refusal_chain()
+tailor_chain = get_tailor_chain()
+cleaner_chain = get_cleaner_chain()
+
+wellness_csv = "AIChatbot.csv"
+brand_csv = "BrandAI.csv"
+wellness_store_dir = "faiss_wellness_store"
+brand_store_dir = "faiss_brand_store"
+
+wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
+brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
+
+gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
+wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
+brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
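
A quick way to exercise the updated module is a small driver script; the sketch below is illustrative only (the file name smoke_test.py and the sample query are made up) and assumes MISTRAL_API_KEY and GEMINI_API_KEY are set and that AIChatbot.csv and BrandAI.csv sit next to pipeline.py. One caveat: this commit makes moderate_text synchronous, but run_async_pipeline still calls "await moderate_text(query)"; awaiting the returned str raises TypeError at runtime, so that call site likely needs its await removed before the sketch runs end to end.

# smoke_test.py -- hypothetical usage sketch, not part of this commit
import os

# Assumptions: both API keys are set and the CSVs exist next to pipeline.py
assert os.environ.get("MISTRAL_API_KEY"), "MISTRAL_API_KEY must be set"
assert os.environ.get("GEMINI_API_KEY"), "GEMINI_API_KEY must be set"

from pipeline import run_with_chain  # importing builds/loads the FAISS stores

if __name__ == "__main__":
    # The keyword list in classify_query should route this to the Wellness chain
    print(run_with_chain("How do I practice box breathing?"))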