Update pipeline.py
pipeline.py  +21 -68  CHANGED
@@ -1,5 +1,7 @@
 import os
 import getpass
 import spacy  # Import spaCy for NER functionality
 import pandas as pd
 from typing import Optional
@@ -10,33 +12,18 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import subprocess  # Import subprocess to run shell commands
 from langchain.llms.base import LLM  # Import LLM
-
-
-#
-
-
-
-
-
-# 1) Environment: set up keys if missing
-if not os.environ.get("GEMINI_API_KEY"):
-    os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
-if not os.environ.get("GROQ_API_KEY"):
-    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
-if not os.environ.get("MISTRAL_API_KEY"):
-    os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter your Mistral API Key: ")
-
-# Initialize Mistral client
-mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
-
-# 2) Load spaCy model for NER and download the spaCy model if not already installed
 def install_spacy_model():
     try:
-        # Check if the model is already installed
         spacy.load("en_core_web_sm")
         print("spaCy model 'en_core_web_sm' is already installed.")
     except OSError:
-        # If model is not installed, download it using subprocess
         print("Downloading spaCy model 'en_core_web_sm'...")
         subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
         print("spaCy model 'en_core_web_sm' downloaded successfully.")
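# --- Illustrative variant (not part of this commit): the download step above
# --- shells out to a bare "python", which assumes that executable is on PATH.
# --- Using sys.executable ties the download to the running interpreter instead.
import sys
import subprocess
import spacy

def install_spacy_model_for_current_interpreter():
    try:
        spacy.load("en_core_web_sm")
    except OSError:
        subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)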
@@ -47,46 +34,16 @@ install_spacy_model()
 # Load the spaCy model globally
 nlp = spacy.load("en_core_web_sm")
 
-# Function to
-def extract_main_topic(query: str) -> str:
-    """
-    Extracts the main topic from the user's query using spaCy's NER.
-    Returns the first named entity or noun found in the query.
-    """
-    doc = nlp(query)  # Use the globally loaded spaCy model
-
-    # Try to extract the main topic as a named entity (person, product, etc.)
-    main_topic = None
-    for ent in doc.ents:
-        # Filter for specific entity types (you can adjust this based on your needs)
-        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:  # Add more entity labels as needed
-            main_topic = ent.text
-            break
-
-    # If no named entity found, fallback to extracting the first noun or proper noun
-    if not main_topic:
-        for token in doc:
-            if token.pos_ in ["NOUN", "PROPN"]:  # Extract first noun or proper noun
-                main_topic = token.text
-                break
-
-    # Return the extracted topic or a fallback value if no topic is found
-    return main_topic if main_topic else "this topic"
-
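# --- Illustrative usage of the removed extract_main_topic() above (not part of
# --- this commit). Whether an entity is found depends on what en_core_web_sm
# --- tags; if nothing matches the listed labels, the first NOUN/PROPN token wins.
topic = extract_main_topic("How do I fine-tune Gemini for a wellness chatbot?")
print(topic)  # e.g. "Gemini" if tagged as an entity, otherwise the first noun/proper noun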
-# 3) Function to moderate text using Mistral moderation API
 def moderate_text(query: str) -> str:
     """
-    Classifies the query as harmful or not using Mistral Moderation
     Returns "OutOfScope" if harmful, otherwise returns the original query.
     """
-    response =
-
-
-
-
-    categories = response.results[0].categories
-
-    # Check if any harmful category is flagged
     if categories.get("violence_and_threats", False) or \
        categories.get("hate_and_discrimination", False) or \
        categories.get("dangerous_and_criminal_content", False) or \
@@ -94,7 +51,7 @@ def moderate_text(query: str) -> str:
         return "OutOfScope"
     return query
 
-#
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
         print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
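# --- Minimal sketch of the build-or-load pattern used above; the rest of the
# --- function body is collapsed in this view. The embedding model and the way
# --- CSV rows become documents are assumptions, not taken from this repo.
import os
import pandas as pd
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

def build_or_load_vectorstore_sketch(csv_path: str, store_dir: str) -> FAISS:
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    if os.path.exists(store_dir):
        # newer LangChain releases may also require allow_dangerous_deserialization=True
        return FAISS.load_local(store_dir, embeddings)
    df = pd.read_csv(csv_path)
    docs = [Document(page_content=str(row)) for row in df.to_dict(orient="records")]
    vectorstore = FAISS.from_documents(docs, embeddings)
    vectorstore.save_local(store_dir)
    return vectorstore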
@@ -123,7 +80,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     vectorstore.save_local(store_dir)
     return vectorstore
 
-#
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
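# --- The GeminiLangChainLLM body is cut off at the hunk boundary above. A minimal
# --- sketch of the wrapper pattern it follows: a LangChain LLM subclass needs
# --- _call and _llm_type, and the chain is assembled with RetrievalQA.from_chain_type.
# --- How the prompt is forwarded to the LiteLLMModel is an assumption here.
from typing import Optional
from langchain.llms.base import LLM
from langchain.chains import RetrievalQA

class GeminiLangChainLLMSketch(LLM):
    def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
        return str(gemini_llm(prompt))  # assumed: the wrapped model accepts a raw prompt string

    @property
    def _llm_type(self) -> str:
        return "gemini_via_litellm"

def build_rag_chain_sketch(vectorstore: FAISS) -> RetrievalQA:
    return RetrievalQA.from_chain_type(
        llm=GeminiLangChainLLMSketch(),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
    )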
@@ -144,13 +101,13 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     )
     return rag_chain
 
-#
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()  # Refusal chain will now use dynamic topic
 tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
-#
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
@@ -163,7 +120,7 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get(
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
-#
 search_tool = DuckDuckGoSearchTool()
 web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
 managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
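# --- Sketch of how the manager agent used below (manager_agent.run) is typically
# --- wired up in smolagents releases that still ship ManagedAgent; the actual
# --- construction in pipeline.py is collapsed in this view, so treat the
# --- managed_agents wiring and the prompt wording as assumptions.
manager_agent_sketch = CodeAgent(
    tools=[],
    model=gemini_llm,
    managed_agents=[managed_web_agent],  # delegate searches to the managed web agent
)

def do_web_search_sketch(query: str) -> str:
    return manager_agent_sketch.run(f"Search the web and summarize: {query}")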
@@ -175,14 +132,10 @@ def do_web_search(query: str) -> str:
     response = manager_agent.run(search_query)
     return response
 
-#
 def run_with_chain(query: str) -> str:
     print("DEBUG: Starting run_with_chain...")
 
-
-    # Ensure the query is a string
-    query = str(query).strip()
-
     # 1) Moderate the query for harmful content
     moderated_query = moderate_text(query)
     if moderated_query == "OutOfScope":
 import os
 import getpass
+from pydantic_ai import Agent  # Import the Agent from pydantic_ai
+from pydantic_ai.models.mistral import MistralModel  # Import the Mistral model
 import spacy  # Import spaCy for NER functionality
 import pandas as pd
 from typing import Optional

 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import subprocess  # Import subprocess to run shell commands
 from langchain.llms.base import LLM  # Import LLM
+
+# Initialize Mistral agent using Pydantic AI
+mistral_api_key = os.environ.get("MISTRAL_API_KEY")  # Ensure your Mistral API key is set
+mistral_model = MistralModel("mistral-large-latest", api_key=mistral_api_key)  # Use a Mistral model
+mistral_agent = Agent(mistral_model)
+
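# --- Illustrative only (not part of this commit): how a pydantic_ai Agent is
# --- normally invoked. run_sync() is the synchronous entry point; the attribute
# --- carrying the model output (.data vs .output) varies across pydantic_ai
# --- versions, so this is an assumption about the installed release.
result = mistral_agent.run_sync("Reply with the single word OK.")
print(result.data)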
+# Load spaCy model for NER and download the spaCy model if not already installed
 def install_spacy_model():
     try:
         spacy.load("en_core_web_sm")
         print("spaCy model 'en_core_web_sm' is already installed.")
     except OSError:
         print("Downloading spaCy model 'en_core_web_sm'...")
         subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
         print("spaCy model 'en_core_web_sm' downloaded successfully.")

 # Load the spaCy model globally
 nlp = spacy.load("en_core_web_sm")
 
+# Function to moderate text using Pydantic AI's Mistral moderation model
 def moderate_text(query: str) -> str:
     """
+    Classifies the query as harmful or not using Mistral Moderation via Pydantic AI.
     Returns "OutOfScope" if harmful, otherwise returns the original query.
     """
+    response = mistral_agent.call("classify", {"inputs": [query]})
+    categories = response['results'][0]['categories']
+
+    # Check if harmful content is flagged in moderation categories
     if categories.get("violence_and_threats", False) or \
        categories.get("hate_and_discrimination", False) or \
        categories.get("dangerous_and_criminal_content", False) or \
 
         return "OutOfScope"
     return query
 
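# --- For reference only (not part of this commit): Mistral also exposes a
# --- dedicated moderation endpoint through the official mistralai client, which
# --- returns the results[0].categories structure the checks above expect. The
# --- Agent.call(...) usage above is assumed to wrap something equivalent.
from mistralai import Mistral

def moderate_with_mistral_client(query: str):
    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
    resp = client.classifiers.moderate(model="mistral-moderation-latest", inputs=[query])
    return resp.results[0].categories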
+# 3) build_or_load_vectorstore (no changes)
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
         print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")

     vectorstore.save_local(store_dir)
     return vectorstore
 
+# 4) Build RAG chain for Gemini (no changes)
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:

     )
     return rag_chain
 
+# 5) Initialize all the separate chains
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()  # Refusal chain will now use dynamic topic
 tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
+# 6) Build our vectorstores + RAG chains
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"

 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
+# 7) Tools / Agents for web search (no changes)
 search_tool = DuckDuckGoSearchTool()
 web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
 managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")

     response = manager_agent.run(search_query)
     return response
 
+# 8) Orchestrator: run_with_chain
 def run_with_chain(query: str) -> str:
     print("DEBUG: Starting run_with_chain...")
 
     # 1) Moderate the query for harmful content
     moderated_query = moderate_text(query)
     if moderated_query == "OutOfScope":
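# --- Example invocation of the orchestrator above (the rest of run_with_chain is
# --- collapsed in this view; the query string is made up for illustration):
if __name__ == "__main__":
    print(run_with_chain("What are some daily habits that improve sleep quality?"))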