Spaces:
Sleeping
Sleeping
Update pipeline.py
Browse files- pipeline.py +52 -47
pipeline.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import os
|
2 |
import getpass
|
3 |
-
import spacy
|
4 |
import pandas as pd
|
5 |
from typing import Optional
|
6 |
from langchain.docstore.document import Document
|
@@ -8,17 +8,11 @@ from langchain.embeddings import HuggingFaceEmbeddings
|
|
8 |
from langchain.vectorstores import FAISS
|
9 |
from langchain.chains import RetrievalQA
|
10 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
|
11 |
-
import subprocess
|
12 |
-
from langchain.llms.base import LLM
|
13 |
-
|
14 |
-
# Import the functions from respective chain files
|
15 |
-
from classification_chain import get_classification_chain
|
16 |
-
from refusal_chain import get_refusal_chain
|
17 |
-
from tailor_chain import get_tailor_chain
|
18 |
-
from cleaner_chain import get_cleaner_chain
|
19 |
|
20 |
# Mistral Client Setup
|
21 |
-
from mistralai import Mistral
|
22 |
from pydantic_ai import Agent # Import Pydantic AI's Agent
|
23 |
|
24 |
# Initialize Mistral API client
|
@@ -28,7 +22,7 @@ client = Mistral(api_key=mistral_api_key)
|
|
28 |
# Initialize Pydantic AI Agent (for text validation)
|
29 |
pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
|
30 |
|
31 |
-
# Load spaCy model for NER and download
|
32 |
def install_spacy_model():
|
33 |
try:
|
34 |
spacy.load("en_core_web_sm")
|
@@ -38,38 +32,53 @@ def install_spacy_model():
|
|
38 |
subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
|
39 |
print("spaCy model 'en_core_web_sm' downloaded successfully.")
|
40 |
|
41 |
-
# Call the function to install the spaCy model if needed
|
42 |
install_spacy_model()
|
43 |
-
|
44 |
-
# Load the spaCy model globally
|
45 |
nlp = spacy.load("en_core_web_sm")
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# Function to moderate text using Mistral moderation API
|
48 |
def moderate_text(query: str) -> str:
|
49 |
"""
|
50 |
Classifies the query as harmful or not using Mistral Moderation via Mistral API.
|
51 |
Returns "OutOfScope" if harmful, otherwise returns the original query.
|
52 |
"""
|
53 |
-
# Validate the text type using Pydantic AI's Agent
|
54 |
try:
|
55 |
-
#
|
56 |
-
pydantic_agent.run_sync(query)
|
57 |
except Exception as e:
|
58 |
-
print(f"Error validating text
|
59 |
return "Invalid text format."
|
60 |
|
61 |
-
# Use the moderation API to evaluate if the query is harmful
|
62 |
response = client.classifiers.moderate_chat(
|
63 |
model="mistral-moderation-latest",
|
64 |
-
inputs=[
|
65 |
-
{"role": "user", "content": query},
|
66 |
-
],
|
67 |
)
|
68 |
|
69 |
-
# Extracting category scores from response
|
70 |
categories = response['results'][0]['categories']
|
71 |
-
|
72 |
-
# Check if harmful content is flagged in moderation categories
|
73 |
if categories.get("violence_and_threats", False) or \
|
74 |
categories.get("hate_and_discrimination", False) or \
|
75 |
categories.get("dangerous_and_criminal_content", False) or \
|
@@ -78,7 +87,7 @@ def moderate_text(query: str) -> str:
|
|
78 |
|
79 |
return query
|
80 |
|
81 |
-
#
|
82 |
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
|
83 |
if os.path.exists(store_dir):
|
84 |
print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
|
@@ -107,7 +116,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
|
|
107 |
vectorstore.save_local(store_dir)
|
108 |
return vectorstore
|
109 |
|
110 |
-
#
|
111 |
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
|
112 |
class GeminiLangChainLLM(LLM):
|
113 |
def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
|
@@ -128,13 +137,18 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
|
|
128 |
)
|
129 |
return rag_chain
|
130 |
|
131 |
-
#
|
132 |
-
classification_chain
|
133 |
-
refusal_chain
|
134 |
-
tailor_chain
|
135 |
-
cleaner_chain
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
-
#
|
138 |
wellness_csv = "AIChatbot.csv"
|
139 |
brand_csv = "BrandAI.csv"
|
140 |
wellness_store_dir = "faiss_wellness_store"
|
@@ -147,7 +161,7 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("
|
|
147 |
wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
|
148 |
brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
|
149 |
|
150 |
-
#
|
151 |
search_tool = DuckDuckGoSearchTool()
|
152 |
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
|
153 |
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
|
@@ -159,32 +173,25 @@ def do_web_search(query: str) -> str:
|
|
159 |
response = manager_agent.run(search_query)
|
160 |
return response
|
161 |
|
162 |
-
#
|
163 |
def run_with_chain(query: str) -> str:
|
164 |
print("DEBUG: Starting run_with_chain...")
|
165 |
|
166 |
-
#
|
167 |
moderated_query = moderate_text(query)
|
168 |
if moderated_query == "OutOfScope":
|
169 |
return "Sorry, this query contains harmful or inappropriate content."
|
170 |
|
171 |
-
#
|
172 |
class_result = classification_chain.invoke({"query": moderated_query})
|
173 |
classification = class_result.get("text", "").strip()
|
174 |
print("DEBUG: Classification =>", classification)
|
175 |
|
176 |
-
# If OutOfScope => refusal => tailor => return
|
177 |
if classification == "OutOfScope":
|
178 |
-
|
179 |
-
topic = extract_main_topic(moderated_query)
|
180 |
-
print("DEBUG: Extracted Topic =>", topic)
|
181 |
-
|
182 |
-
# Pass the extracted topic to the refusal chain
|
183 |
-
refusal_text = refusal_chain.run({"topic": topic})
|
184 |
final_refusal = tailor_chain.run({"response": refusal_text})
|
185 |
return final_refusal.strip()
|
186 |
|
187 |
-
# If Wellness => wellness RAG => if insufficient => web => unify => tailor
|
188 |
if classification == "Wellness":
|
189 |
rag_result = wellness_rag_chain({"query": moderated_query})
|
190 |
csv_answer = rag_result["result"].strip()
|
@@ -200,7 +207,6 @@ def run_with_chain(query: str) -> str:
|
|
200 |
final_answer = tailor_chain.run({"response": final_merged})
|
201 |
return final_answer.strip()
|
202 |
|
203 |
-
# If Brand => brand RAG => tailor => return
|
204 |
if classification == "Brand":
|
205 |
rag_result = brand_rag_chain({"query": moderated_query})
|
206 |
csv_answer = rag_result["result"].strip()
|
@@ -208,7 +214,6 @@ def run_with_chain(query: str) -> str:
|
|
208 |
final_answer = tailor_chain.run({"response": final_merged})
|
209 |
return final_answer.strip()
|
210 |
|
211 |
-
# fallback
|
212 |
refusal_text = refusal_chain.run({"topic": "this topic"})
|
213 |
final_refusal = tailor_chain.run({"response": refusal_text})
|
214 |
return final_refusal.strip()
|
|
|
1 |
import os
|
2 |
import getpass
|
3 |
+
import spacy
|
4 |
import pandas as pd
|
5 |
from typing import Optional
|
6 |
from langchain.docstore.document import Document
|
|
|
8 |
from langchain.vectorstores import FAISS
|
9 |
from langchain.chains import RetrievalQA
|
10 |
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
|
11 |
+
import subprocess
|
12 |
+
from langchain.llms.base import LLM
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Mistral Client Setup
|
15 |
+
from mistralai import Mistral
|
16 |
from pydantic_ai import Agent # Import Pydantic AI's Agent
|
17 |
|
18 |
# Initialize Mistral API client
|
|
|
22 |
# Initialize Pydantic AI Agent (for text validation)
|
23 |
pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
|
24 |
|
25 |
+
# Load spaCy model for NER and download it if not already installed
|
26 |
def install_spacy_model():
|
27 |
try:
|
28 |
spacy.load("en_core_web_sm")
|
|
|
32 |
subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
|
33 |
print("spaCy model 'en_core_web_sm' downloaded successfully.")
|
34 |
|
|
|
35 |
install_spacy_model()
|
|
|
|
|
36 |
nlp = spacy.load("en_core_web_sm")
|
37 |
|
38 |
+
# Function to extract the main topic from the query using spaCy NER
|
39 |
+
def extract_main_topic(query: str) -> str:
|
40 |
+
"""
|
41 |
+
Extracts the main topic from the user's query using spaCy's NER.
|
42 |
+
Returns the first named entity or noun found in the query.
|
43 |
+
"""
|
44 |
+
doc = nlp(query)
|
45 |
+
|
46 |
+
# Try to extract the main topic as a named entity (person, product, etc.)
|
47 |
+
main_topic = None
|
48 |
+
for ent in doc.ents:
|
49 |
+
# Filter for specific entity types (you can adjust this based on your needs)
|
50 |
+
if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]: # Add more entity labels as needed
|
51 |
+
main_topic = ent.text
|
52 |
+
break
|
53 |
+
|
54 |
+
# If no named entity found, fallback to extracting the first noun or proper noun
|
55 |
+
if not main_topic:
|
56 |
+
for token in doc:
|
57 |
+
if token.pos_ in ["NOUN", "PROPN"]: # Extract first noun or proper noun
|
58 |
+
main_topic = token.text
|
59 |
+
break
|
60 |
+
|
61 |
+
# Return the extracted topic or a fallback value if no topic is found
|
62 |
+
return main_topic if main_topic else "this topic"
|
63 |
+
|
64 |
# Function to moderate text using Mistral moderation API
|
65 |
def moderate_text(query: str) -> str:
|
66 |
"""
|
67 |
Classifies the query as harmful or not using Mistral Moderation via Mistral API.
|
68 |
Returns "OutOfScope" if harmful, otherwise returns the original query.
|
69 |
"""
|
|
|
70 |
try:
|
71 |
+
pydantic_agent.run_sync(query) # Validate input
|
|
|
72 |
except Exception as e:
|
73 |
+
print(f"Error validating text: {e}")
|
74 |
return "Invalid text format."
|
75 |
|
|
|
76 |
response = client.classifiers.moderate_chat(
|
77 |
model="mistral-moderation-latest",
|
78 |
+
inputs=[{"role": "user", "content": query}]
|
|
|
|
|
79 |
)
|
80 |
|
|
|
81 |
categories = response['results'][0]['categories']
|
|
|
|
|
82 |
if categories.get("violence_and_threats", False) or \
|
83 |
categories.get("hate_and_discrimination", False) or \
|
84 |
categories.get("dangerous_and_criminal_content", False) or \
|
|
|
87 |
|
88 |
return query
|
89 |
|
90 |
+
# Build or load vectorstore function
|
91 |
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
|
92 |
if os.path.exists(store_dir):
|
93 |
print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
|
|
|
116 |
vectorstore.save_local(store_dir)
|
117 |
return vectorstore
|
118 |
|
119 |
+
# Build RAG chain for Gemini (no changes)
|
120 |
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
|
121 |
class GeminiLangChainLLM(LLM):
|
122 |
def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
|
|
|
137 |
)
|
138 |
return rag_chain
|
139 |
|
140 |
+
# Initialize all the separate chains
|
141 |
+
from classification_chain import get_classification_chain
|
142 |
+
from refusal_chain import get_refusal_chain
|
143 |
+
from tailor_chain import get_tailor_chain
|
144 |
+
from cleaner_chain import get_cleaner_chain
|
145 |
+
|
146 |
+
classification_chain = get_classification_chain() # Ensure this function is imported correctly
|
147 |
+
refusal_chain = get_refusal_chain() # Ensure this function is imported correctly
|
148 |
+
tailor_chain = get_tailor_chain() # Ensure this function is imported correctly
|
149 |
+
cleaner_chain = get_cleaner_chain() # Ensure this function is imported correctly
|
150 |
|
151 |
+
# Build our vectorstores + RAG chains
|
152 |
wellness_csv = "AIChatbot.csv"
|
153 |
brand_csv = "BrandAI.csv"
|
154 |
wellness_store_dir = "faiss_wellness_store"
|
|
|
161 |
wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
|
162 |
brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
|
163 |
|
164 |
+
# Tools / Agents for web search
|
165 |
search_tool = DuckDuckGoSearchTool()
|
166 |
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
|
167 |
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
|
|
|
173 |
response = manager_agent.run(search_query)
|
174 |
return response
|
175 |
|
176 |
+
# Orchestrator: run_with_chain
|
177 |
def run_with_chain(query: str) -> str:
|
178 |
print("DEBUG: Starting run_with_chain...")
|
179 |
|
180 |
+
# Moderate the query for harmful content
|
181 |
moderated_query = moderate_text(query)
|
182 |
if moderated_query == "OutOfScope":
|
183 |
return "Sorry, this query contains harmful or inappropriate content."
|
184 |
|
185 |
+
# Classify the query
|
186 |
class_result = classification_chain.invoke({"query": moderated_query})
|
187 |
classification = class_result.get("text", "").strip()
|
188 |
print("DEBUG: Classification =>", classification)
|
189 |
|
|
|
190 |
if classification == "OutOfScope":
|
191 |
+
refusal_text = refusal_chain.run({"topic": "this topic"})
|
|
|
|
|
|
|
|
|
|
|
192 |
final_refusal = tailor_chain.run({"response": refusal_text})
|
193 |
return final_refusal.strip()
|
194 |
|
|
|
195 |
if classification == "Wellness":
|
196 |
rag_result = wellness_rag_chain({"query": moderated_query})
|
197 |
csv_answer = rag_result["result"].strip()
|
|
|
207 |
final_answer = tailor_chain.run({"response": final_merged})
|
208 |
return final_answer.strip()
|
209 |
|
|
|
210 |
if classification == "Brand":
|
211 |
rag_result = brand_rag_chain({"query": moderated_query})
|
212 |
csv_answer = rag_result["result"].strip()
|
|
|
214 |
final_answer = tailor_chain.run({"response": final_merged})
|
215 |
return final_answer.strip()
|
216 |
|
|
|
217 |
refusal_text = refusal_chain.run({"topic": "this topic"})
|
218 |
final_refusal = tailor_chain.run({"response": refusal_text})
|
219 |
return final_refusal.strip()
|