Phoenix21 committed
Commit 293661c · verified · 1 Parent(s): cd1a95e

Update pipeline.py

Files changed (1):
  1. pipeline.py +30 -188

pipeline.py CHANGED
@@ -3,20 +3,19 @@ import getpass
 import spacy
 import pandas as pd
 from typing import Optional
+import subprocess
+from langchain.llms.base import LLM
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
-import subprocess
-from langchain.llms.base import LLM
-
-# Mistral Client Setup
-from mistralai import Mistral
 from pydantic_ai import Agent # Import Pydantic AI's Agent
+from mistralai import Mistral
+import asyncio # Needed for managing async tasks

 # Initialize Mistral API client
-mistral_api_key = os.environ.get("MISTRAL_API_KEY") # Ensure your Mistral API key is set
+mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)

 # Initialize Pydantic AI Agent (for text validation)
@@ -37,47 +36,40 @@ nlp = spacy.load("en_core_web_sm")

 # Function to extract the main topic from the query using spaCy NER
 def extract_main_topic(query: str) -> str:
-    """
-    Extracts the main topic from the user's query using spaCy's NER.
-    Returns the first named entity or noun found in the query.
-    """
     doc = nlp(query)
-
-    # Try to extract the main topic as a named entity (person, product, etc.)
     main_topic = None
     for ent in doc.ents:
-        # Filter for specific entity types (you can adjust this based on your needs)
-        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]: # Add more entity labels as needed
+        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:
             main_topic = ent.text
             break
-
-    # If no named entity found, fallback to extracting the first noun or proper noun
     if not main_topic:
         for token in doc:
-            if token.pos_ in ["NOUN", "PROPN"]: # Extract first noun or proper noun
+            if token.pos_ in ["NOUN", "PROPN"]:
                 main_topic = token.text
                 break
-
-    # Return the extracted topic or a fallback value if no topic is found
     return main_topic if main_topic else "this topic"

-# Function to moderate text using Mistral moderation API
-def moderate_text(query: str) -> str:
-    """
-    Classifies the query as harmful or not using Mistral Moderation via Mistral API.
-    Returns "OutOfScope" if harmful, otherwise returns the original query.
-    """
+# Function to classify query based on wellness topics
+def classify_query(query: str) -> str:
+    wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
+    if any(keyword in query.lower() for keyword in wellness_keywords):
+        return "Wellness"
+    class_result = classification_chain.invoke({"query": query})
+    classification = class_result.get("text", "").strip()
+    return classification if classification != "OutOfScope" else "OutOfScope"
+
+# Function to moderate text using Mistral moderation API (async version)
+async def moderate_text(query: str) -> str:
     try:
-        pydantic_agent.run_sync(query) # Validate input
+        await pydantic_agent.run(query) # Use async run for Pydantic validation
     except Exception as e:
         print(f"Error validating text: {e}")
         return "Invalid text format."

-    response = client.classifiers.moderate_chat(
+    response = await client.classifiers.moderate_chat(
         model="mistral-moderation-latest",
         inputs=[{"role": "user", "content": query}]
     )
-
     categories = response['results'][0]['categories']
     if categories.get("violence_and_threats", False) or \
        categories.get("hate_and_discrimination", False) or \
@@ -87,163 +79,15 @@ def moderate_text(query: str) -> str:

     return query

-# Build or load vectorstore function
-def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
-    if os.path.exists(store_dir):
-        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
-        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
-        vectorstore = FAISS.load_local(store_dir, embeddings)
-        return vectorstore
-    else:
-        print(f"DEBUG: Building new store from CSV: {csv_path}")
-        df = pd.read_csv(csv_path)
-        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
-        df.columns = df.columns.str.strip()
-        if "Answer" in df.columns:
-            df.rename(columns={"Answer": "Answers"}, inplace=True)
-        if "Question" not in df.columns and "Question " in df.columns:
-            df.rename(columns={"Question ": "Question"}, inplace=True)
-        if "Question" not in df.columns or "Answers" not in df.columns:
-            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
-        docs = []
-        for _, row in df.iterrows():
-            q = str(row["Question"])
-            ans = str(row["Answers"])
-            doc = Document(page_content=ans, metadata={"question": q})
-            docs.append(doc)
-        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
-        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
-        vectorstore.save_local(store_dir)
-        return vectorstore
-
-# Build RAG chain for Gemini (no changes)
-def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
-    class GeminiLangChainLLM(LLM):
-        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
-            messages = [{"role": "user", "content": prompt}]
-            return llm_model(messages, stop_sequences=stop)
-
-        @property
-        def _llm_type(self) -> str:
-            return "custom_gemini"
-
-    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
-    gemini_as_llm = GeminiLangChainLLM()
-    rag_chain = RetrievalQA.from_chain_type(
-        llm=gemini_as_llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True
-    )
-    return rag_chain
-
-# Initialize all the separate chains
-from classification_chain import get_classification_chain
-from refusal_chain import get_refusal_chain
-from tailor_chain import get_tailor_chain
-from cleaner_chain import get_cleaner_chain
-
-classification_chain = get_classification_chain() # Ensure this function is imported correctly
-refusal_chain = get_refusal_chain() # Ensure this function is imported correctly
-tailor_chain = get_tailor_chain() # Ensure this function is imported correctly
-cleaner_chain = get_cleaner_chain() # Ensure this function is imported correctly
-
-# Build our vectorstores + RAG chains
-wellness_csv = "AIChatbot.csv"
-brand_csv = "BrandAI.csv"
-wellness_store_dir = "faiss_wellness_store"
-brand_store_dir = "faiss_brand_store"
-
-wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
-brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
-
-gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
-wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
-brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
-
-# Tools / Agents for web search
-search_tool = DuckDuckGoSearchTool()
-web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
-managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
-manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])
-
-def do_web_search(query: str) -> str:
-    print("DEBUG: Attempting web search for more info...")
-    search_query = f"Give me relevant info: {query}"
-    response = manager_agent.run(search_query)
-    return response
-
-# Modify the classification logic to recognize box breathing
-def classify_query(query: str) -> str:
-    wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
-
-    if any(keyword in query.lower() for keyword in wellness_keywords):
-        return "Wellness"
-
-    # If not recognized as wellness, use the classification chain
-    class_result = classification_chain.invoke({"query": query})
-    classification = class_result.get("text", "").strip()
-
-    if classification == "OutOfScope":
-        return "OutOfScope"
-    return classification
-
-# # Orchestrator: run_with_chain
-# def run_with_chain(query: str) -> str:
-#     print("DEBUG: Starting run_with_chain...")
-
-#     # Moderate the query for harmful content
-#     moderated_query = moderate_text(query)
-#     if moderated_query == "OutOfScope":
-#         return "Sorry, this query contains harmful or inappropriate content."
-
-#     # Classify the query
-#     class_result = classification_chain.invoke({"query": moderated_query})
-#     classification = class_result.get("text", "").strip()
-#     print("DEBUG: Classification =>", classification)
-
-#     if classification == "OutOfScope":
-#         refusal_text = refusal_chain.run({"topic": "this topic"})
-#         final_refusal = tailor_chain.run({"response": refusal_text})
-#         return final_refusal.strip()
-
-#     if classification == "Wellness":
-#         rag_result = wellness_rag_chain({"query": moderated_query})
-#         csv_answer = rag_result["result"].strip()
-#         if not csv_answer:
-#             web_answer = do_web_search(moderated_query)
-#         else:
-#             lower_ans = csv_answer.lower()
-#             if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
-#                 web_answer = do_web_search(moderated_query)
-#             else:
-#                 web_answer = ""
-#         final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
-#         final_answer = tailor_chain.run({"response": final_merged})
-#         return final_answer.strip()
-
-#     if classification == "Brand":
-#         rag_result = brand_rag_chain({"query": moderated_query})
-#         csv_answer = rag_result["result"].strip()
-#         final_merged = cleaner_chain.merge(kb=csv_answer, web="")
-#         final_answer = tailor_chain.run({"response": final_merged})
-#         return final_answer.strip()
-
-#     refusal_text = refusal_chain.run({"topic": "this topic"})
-#     final_refusal = tailor_chain.run({"response": refusal_text})
-#     return final_refusal.strip()
-
-def run_with_chain(query: str) -> str:
-    print("DEBUG: Starting run_with_chain...")
-
-    # Moderate the query for harmful content
-    moderated_query = moderate_text(query)
+# Use the event loop to run the async functions properly
+async def run_async_pipeline(query: str) -> str:
+    # Moderate the query for harmful content (async)
+    moderated_query = await moderate_text(query)
     if moderated_query == "OutOfScope":
         return "Sorry, this query contains harmful or inappropriate content."

-    # Classify the query manually, ensuring box breathing is recognized
+    # Classify the query manually
     classification = classify_query(moderated_query)
-    print("DEBUG: Classification =>", classification)

     if classification == "OutOfScope":
         refusal_text = refusal_chain.run({"topic": "this topic"})
@@ -253,14 +97,9 @@ def run_with_chain(query: str) -> str:
     if classification == "Wellness":
         rag_result = wellness_rag_chain({"query": moderated_query})
         csv_answer = rag_result["result"].strip()
+        web_answer = "" # Empty if we found an answer from the knowledge base
         if not csv_answer:
-            web_answer = do_web_search(moderated_query)
-        else:
-            lower_ans = csv_answer.lower()
-            if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
-                web_answer = do_web_search(moderated_query)
-            else:
-                web_answer = ""
+            web_answer = await do_web_search(moderated_query)
         final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
         final_answer = tailor_chain.run({"response": final_merged})
         return final_answer.strip()
@@ -276,3 +115,6 @@ def run_with_chain(query: str) -> str:
     final_refusal = tailor_chain.run({"response": refusal_text})
     return final_refusal.strip()

+# Run the pipeline with the event loop
+def run_with_chain(query: str) -> str:
+    return asyncio.run(run_async_pipeline(query))
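For quick reference, a minimal usage sketch of the new synchronous entry point. The import path, example query, and `__main__` guard below are illustrative assumptions, and the sketch presumes the supporting chains (classification_chain, wellness_rag_chain, tailor_chain, etc.) and API keys such as MISTRAL_API_KEY are available to pipeline.py.

# Hypothetical usage sketch (not part of the commit): run_with_chain() is the new
# synchronous wrapper that drives the async pipeline via asyncio.run().
from pipeline import run_with_chain  # assumes pipeline.py is importable

if __name__ == "__main__":
    # classify_query() short-circuits on wellness keywords such as "box breathing",
    # so this query should route through the Wellness RAG chain.
    print(run_with_chain("Can you walk me through a box breathing exercise?"))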