Phoenix21 committed on
Commit
74221f2
·
verified ·
1 Parent(s): 774c0b8

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +52 -47
pipeline.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import getpass
3
- import spacy # Import spaCy for NER functionality
4
  import pandas as pd
5
  from typing import Optional
6
  from langchain.docstore.document import Document
@@ -8,17 +8,11 @@ from langchain.embeddings import HuggingFaceEmbeddings
8
  from langchain.vectorstores import FAISS
9
  from langchain.chains import RetrievalQA
10
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
11
- import subprocess # Import subprocess to run shell commands
12
- from langchain.llms.base import LLM # Import LLM
13
-
14
- # Import the functions from respective chain files
15
- from classification_chain import get_classification_chain
16
- from refusal_chain import get_refusal_chain
17
- from tailor_chain import get_tailor_chain
18
- from cleaner_chain import get_cleaner_chain
19
 
20
  # Mistral Client Setup
21
- from mistralai import Mistral # Import the Mistral client
22
  from pydantic_ai import Agent # Import Pydantic AI's Agent
23
 
24
  # Initialize Mistral API client
@@ -28,7 +22,7 @@ client = Mistral(api_key=mistral_api_key)
28
  # Initialize Pydantic AI Agent (for text validation)
29
  pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
30
 
31
- # Load spaCy model for NER and download the spaCy model if not already installed
32
  def install_spacy_model():
33
  try:
34
  spacy.load("en_core_web_sm")
@@ -38,38 +32,53 @@ def install_spacy_model():
38
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
39
  print("spaCy model 'en_core_web_sm' downloaded successfully.")
40
 
41
- # Call the function to install the spaCy model if needed
42
  install_spacy_model()
43
-
44
- # Load the spaCy model globally
45
  nlp = spacy.load("en_core_web_sm")
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Function to moderate text using Mistral moderation API
48
  def moderate_text(query: str) -> str:
49
  """
50
  Classifies the query as harmful or not using Mistral Moderation via Mistral API.
51
  Returns "OutOfScope" if harmful, otherwise returns the original query.
52
  """
53
- # Validate the text type using Pydantic AI's Agent
54
  try:
55
- # Use Pydantic AI agent to ensure correct text type
56
- pydantic_agent.run_sync(query)
57
  except Exception as e:
58
- print(f"Error validating text with Pydantic AI: {e}")
59
  return "Invalid text format."
60
 
61
- # Use the moderation API to evaluate if the query is harmful
62
  response = client.classifiers.moderate_chat(
63
  model="mistral-moderation-latest",
64
- inputs=[
65
- {"role": "user", "content": query},
66
- ],
67
  )
68
 
69
- # Extracting category scores from response
70
  categories = response['results'][0]['categories']
71
-
72
- # Check if harmful content is flagged in moderation categories
73
  if categories.get("violence_and_threats", False) or \
74
  categories.get("hate_and_discrimination", False) or \
75
  categories.get("dangerous_and_criminal_content", False) or \
@@ -78,7 +87,7 @@ def moderate_text(query: str) -> str:
78
 
79
  return query
80
 
81
- # 3) build_or_load_vectorstore (no changes)
82
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
83
  if os.path.exists(store_dir):
84
  print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
@@ -107,7 +116,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
107
  vectorstore.save_local(store_dir)
108
  return vectorstore
109
 
110
- # 4) Build RAG chain for Gemini (no changes)
111
  def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
112
  class GeminiLangChainLLM(LLM):
113
  def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
@@ -128,13 +137,18 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
128
  )
129
  return rag_chain
130
 
131
- # 5) Initialize all the separate chains
132
- classification_chain = get_classification_chain()
133
- refusal_chain = get_refusal_chain() # Refusal chain will now use dynamic topic
134
- tailor_chain = get_tailor_chain()
135
- cleaner_chain = get_cleaner_chain()
 
 
 
 
 
136
 
137
- # 6) Build our vectorstores + RAG chains
138
  wellness_csv = "AIChatbot.csv"
139
  brand_csv = "BrandAI.csv"
140
  wellness_store_dir = "faiss_wellness_store"
@@ -147,7 +161,7 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("
147
  wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
148
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
149
 
150
- # 7) Tools / Agents for web search (no changes)
151
  search_tool = DuckDuckGoSearchTool()
152
  web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
153
  managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
@@ -159,32 +173,25 @@ def do_web_search(query: str) -> str:
159
  response = manager_agent.run(search_query)
160
  return response
161
 
162
- # 8) Orchestrator: run_with_chain
163
  def run_with_chain(query: str) -> str:
164
  print("DEBUG: Starting run_with_chain...")
165
 
166
- # 1) Moderate the query for harmful content
167
  moderated_query = moderate_text(query)
168
  if moderated_query == "OutOfScope":
169
  return "Sorry, this query contains harmful or inappropriate content."
170
 
171
- # 2) Classify the query
172
  class_result = classification_chain.invoke({"query": moderated_query})
173
  classification = class_result.get("text", "").strip()
174
  print("DEBUG: Classification =>", classification)
175
 
176
- # If OutOfScope => refusal => tailor => return
177
  if classification == "OutOfScope":
178
- # Extract the main topic for the refusal message
179
- topic = extract_main_topic(moderated_query)
180
- print("DEBUG: Extracted Topic =>", topic)
181
-
182
- # Pass the extracted topic to the refusal chain
183
- refusal_text = refusal_chain.run({"topic": topic})
184
  final_refusal = tailor_chain.run({"response": refusal_text})
185
  return final_refusal.strip()
186
 
187
- # If Wellness => wellness RAG => if insufficient => web => unify => tailor
188
  if classification == "Wellness":
189
  rag_result = wellness_rag_chain({"query": moderated_query})
190
  csv_answer = rag_result["result"].strip()
@@ -200,7 +207,6 @@ def run_with_chain(query: str) -> str:
200
  final_answer = tailor_chain.run({"response": final_merged})
201
  return final_answer.strip()
202
 
203
- # If Brand => brand RAG => tailor => return
204
  if classification == "Brand":
205
  rag_result = brand_rag_chain({"query": moderated_query})
206
  csv_answer = rag_result["result"].strip()
@@ -208,7 +214,6 @@ def run_with_chain(query: str) -> str:
208
  final_answer = tailor_chain.run({"response": final_merged})
209
  return final_answer.strip()
210
 
211
- # fallback
212
  refusal_text = refusal_chain.run({"topic": "this topic"})
213
  final_refusal = tailor_chain.run({"response": refusal_text})
214
  return final_refusal.strip()
 
1
  import os
2
  import getpass
3
+ import spacy
4
  import pandas as pd
5
  from typing import Optional
6
  from langchain.docstore.document import Document
 
8
  from langchain.vectorstores import FAISS
9
  from langchain.chains import RetrievalQA
10
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
11
+ import subprocess
12
+ from langchain.llms.base import LLM
 
 
 
 
 
 
13
 
14
  # Mistral Client Setup
15
+ from mistralai import Mistral
16
  from pydantic_ai import Agent # Import Pydantic AI's Agent
17
 
18
  # Initialize Mistral API client
 
22
  # Initialize Pydantic AI Agent (for text validation)
23
  pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
24
 
25
+ # Load spaCy model for NER and download it if not already installed
26
  def install_spacy_model():
27
  try:
28
  spacy.load("en_core_web_sm")
 
32
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
33
  print("spaCy model 'en_core_web_sm' downloaded successfully.")
34
 
 
35
  install_spacy_model()
 
 
36
  nlp = spacy.load("en_core_web_sm")
37
 
38
+ # Function to extract the main topic from the query using spaCy NER
39
def extract_main_topic(query: str) -> str:
    """Pick a short topic string out of ``query`` using the spaCy pipeline.

    The text of the first named entity labelled ORG, PRODUCT, PERSON, GPE
    or TIME is preferred. When no such entity exists, the text of the first
    token tagged NOUN or PROPN is used instead. If neither is found, the
    placeholder "this topic" is returned.
    """
    parsed = nlp(query)

    # Entity labels accepted as a topic; extend the list as needed.
    wanted_labels = ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]
    topic = next(
        (entity.text for entity in parsed.ents if entity.label_ in wanted_labels),
        None,
    )

    # No usable entity: fall back to the first noun or proper noun token.
    if not topic:
        topic = next(
            (tok.text for tok in parsed if tok.pos_ in ["NOUN", "PROPN"]),
            None,
        )

    if topic:
        return topic
    return "this topic"
63
+
64
  # Function to moderate text using Mistral moderation API
65
  def moderate_text(query: str) -> str:
66
  """
67
  Classifies the query as harmful or not using Mistral Moderation via Mistral API.
68
  Returns "OutOfScope" if harmful, otherwise returns the original query.
69
  """
 
70
  try:
71
+ pydantic_agent.run_sync(query) # Validate input
 
72
  except Exception as e:
73
+ print(f"Error validating text: {e}")
74
  return "Invalid text format."
75
 
 
76
  response = client.classifiers.moderate_chat(
77
  model="mistral-moderation-latest",
78
+ inputs=[{"role": "user", "content": query}]
 
 
79
  )
80
 
 
81
  categories = response['results'][0]['categories']
 
 
82
  if categories.get("violence_and_threats", False) or \
83
  categories.get("hate_and_discrimination", False) or \
84
  categories.get("dangerous_and_criminal_content", False) or \
 
87
 
88
  return query
89
 
90
+ # Build or load vectorstore function
91
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
92
  if os.path.exists(store_dir):
93
  print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
 
116
  vectorstore.save_local(store_dir)
117
  return vectorstore
118
 
119
+ # Build RAG chain for Gemini
120
  def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
121
  class GeminiLangChainLLM(LLM):
122
  def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
 
137
  )
138
  return rag_chain
139
 
140
+ # Initialize all the separate chains
141
+ from classification_chain import get_classification_chain
142
+ from refusal_chain import get_refusal_chain
143
+ from tailor_chain import get_tailor_chain
144
+ from cleaner_chain import get_cleaner_chain
145
+
146
+ classification_chain = get_classification_chain() # Ensure this function is imported correctly
147
+ refusal_chain = get_refusal_chain() # Ensure this function is imported correctly
148
+ tailor_chain = get_tailor_chain() # Ensure this function is imported correctly
149
+ cleaner_chain = get_cleaner_chain() # Ensure this function is imported correctly
150
 
151
+ # Build our vectorstores + RAG chains
152
  wellness_csv = "AIChatbot.csv"
153
  brand_csv = "BrandAI.csv"
154
  wellness_store_dir = "faiss_wellness_store"
 
161
  wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
162
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
163
 
164
+ # Tools / Agents for web search
165
  search_tool = DuckDuckGoSearchTool()
166
  web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
167
  managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
 
173
  response = manager_agent.run(search_query)
174
  return response
175
 
176
+ # Orchestrator: run_with_chain
177
  def run_with_chain(query: str) -> str:
178
  print("DEBUG: Starting run_with_chain...")
179
 
180
+ # Moderate the query for harmful content
181
  moderated_query = moderate_text(query)
182
  if moderated_query == "OutOfScope":
183
  return "Sorry, this query contains harmful or inappropriate content."
184
 
185
+ # Classify the query
186
  class_result = classification_chain.invoke({"query": moderated_query})
187
  classification = class_result.get("text", "").strip()
188
  print("DEBUG: Classification =>", classification)
189
 
 
190
  if classification == "OutOfScope":
191
+ refusal_text = refusal_chain.run({"topic": "this topic"})
 
 
 
 
 
192
  final_refusal = tailor_chain.run({"response": refusal_text})
193
  return final_refusal.strip()
194
 
 
195
  if classification == "Wellness":
196
  rag_result = wellness_rag_chain({"query": moderated_query})
197
  csv_answer = rag_result["result"].strip()
 
207
  final_answer = tailor_chain.run({"response": final_merged})
208
  return final_answer.strip()
209
 
 
210
  if classification == "Brand":
211
  rag_result = brand_rag_chain({"query": moderated_query})
212
  csv_answer = rag_result["result"].strip()
 
214
  final_answer = tailor_chain.run({"response": final_merged})
215
  return final_answer.strip()
216
 
 
217
  refusal_text = refusal_chain.run({"topic": "this topic"})
218
  final_refusal = tailor_chain.run({"response": refusal_text})
219
  return final_refusal.strip()