import os
import sys  # Used to invoke the current Python interpreter for the spaCy model download
import getpass
import subprocess  # Run shell commands (spaCy model download)
from typing import Optional

import pandas as pd
import spacy  # spaCy for NER functionality

from pydantic_ai import Agent  # Pydantic AI agent
from pydantic_ai.models.mistral import MistralModel  # Mistral model wrapper for Pydantic AI
from mistralai import Mistral  # Official Mistral client, used below for the moderation endpoint

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM

from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
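
# The chain factories used later (get_classification_chain, get_refusal_chain,
# get_tailor_chain, get_cleaner_chain) are not defined in this file. The commented
# imports below are only a sketch: the module names are assumptions and should be
# adjusted to match the actual project layout.
# from classification_chain import get_classification_chain
# from refusal_chain import get_refusal_chain
# from tailor_chain import get_tailor_chain
# from cleaner_chain import get_cleaner_chain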
# Initialize the Mistral agent using Pydantic AI
mistral_api_key = os.environ.get("MISTRAL_API_KEY")  # Ensure your Mistral API key is set
mistral_model = MistralModel("mistral-large-latest", api_key=mistral_api_key)  # Mistral chat model for Pydantic AI
mistral_agent = Agent(mistral_model)

# Pydantic AI agents handle chat completions, not moderation classifiers, so the
# moderation endpoint is called through the official mistralai client instead.
mistral_client = Mistral(api_key=mistral_api_key)
# Load the spaCy model for NER, downloading it first if it is not already installed
def install_spacy_model():
    try:
        spacy.load("en_core_web_sm")
        print("spaCy model 'en_core_web_sm' is already installed.")
    except OSError:
        print("Downloading spaCy model 'en_core_web_sm'...")
        # Use the current interpreter so the model is installed into the active environment
        subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
        print("spaCy model 'en_core_web_sm' downloaded successfully.")

# Install the spaCy model if needed
install_spacy_model()

# Load the spaCy model globally
nlp = spacy.load("en_core_web_sm")
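
# run_with_chain() below calls extract_main_topic(), which is not defined in this
# section. The following is a minimal sketch of such a helper using the spaCy
# pipeline loaded above: it returns the first named entity or noun chunk as the
# topic, falling back to a generic phrase.
def extract_main_topic(query: str) -> str:
    doc = nlp(query)
    # Prefer a named entity if spaCy finds one
    if doc.ents:
        return doc.ents[0].text
    # Otherwise fall back to the first noun chunk
    noun_chunks = list(doc.noun_chunks)
    if noun_chunks:
        return noun_chunks[0].text
    return "this topic"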
# Function to moderate text with Mistral's moderation endpoint before it reaches any chain
def moderate_text(query: str) -> str:
    """
    Classifies the query as harmful or not using the Mistral moderation API.
    Returns "OutOfScope" if harmful, otherwise returns the original query.
    """
    response = mistral_client.classifiers.moderate(
        model="mistral-moderation-latest",
        inputs=[query],
    )
    categories = response.results[0].categories
    # Check if harmful content is flagged in the moderation categories
    if categories.get("violence_and_threats", False) or \
       categories.get("hate_and_discrimination", False) or \
       categories.get("dangerous_and_criminal_content", False) or \
       categories.get("selfharm", False):
        return "OutOfScope"
    return query
# 3) build_or_load_vectorstore (no changes)
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    if os.path.exists(store_dir):
        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
        vectorstore = FAISS.load_local(store_dir, embeddings)
        return vectorstore
    else:
        print(f"DEBUG: Building new store from CSV: {csv_path}")
        df = pd.read_csv(csv_path)
        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
        df.columns = df.columns.str.strip()
        if "Answer" in df.columns:
            df.rename(columns={"Answer": "Answers"}, inplace=True)
        if "Question" not in df.columns and "Question " in df.columns:
            df.rename(columns={"Question ": "Question"}, inplace=True)
        if "Question" not in df.columns or "Answers" not in df.columns:
            raise ValueError("CSV must have 'Question' and 'Answers' columns.")
        docs = []
        for _, row in df.iterrows():
            q = str(row["Question"])
            ans = str(row["Answers"])
            doc = Document(page_content=ans, metadata={"question": q})
            docs.append(doc)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
        vectorstore.save_local(store_dir)
        return vectorstore
# 4) Build RAG chain for Gemini (no changes)
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    # Thin LangChain LLM wrapper so the smolagents LiteLLMModel can drive RetrievalQA
    class GeminiLangChainLLM(LLM):
        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
            messages = [{"role": "user", "content": prompt}]
            return llm_model(messages, stop_sequences=stop)

        @property
        def _llm_type(self) -> str:
            return "custom_gemini"

    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    gemini_as_llm = GeminiLangChainLLM()
    rag_chain = RetrievalQA.from_chain_type(
        llm=gemini_as_llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True
    )
    return rag_chain
# 5) Initialize all the separate chains
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()  # The refusal chain now uses a dynamic topic
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()
# 6) Build our vectorstores + RAG chains
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"

wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)

gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
# 7) Tools / Agents for web search (no changes)
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])

def do_web_search(query: str) -> str:
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    response = manager_agent.run(search_query)
    return response
# 8) Orchestrator: run_with_chain
def run_with_chain(query: str) -> str:
    print("DEBUG: Starting run_with_chain...")

    # 1) Moderate the query for harmful content
    moderated_query = moderate_text(query)
    if moderated_query == "OutOfScope":
        return "Sorry, this query contains harmful or inappropriate content."

    # 2) Classify the query
    class_result = classification_chain.invoke({"query": moderated_query})
    classification = class_result.get("text", "").strip()
    print("DEBUG: Classification =>", classification)

    # If OutOfScope => refusal => tailor => return
    if classification == "OutOfScope":
        # Extract the main topic for the refusal message
        topic = extract_main_topic(moderated_query)
        print("DEBUG: Extracted Topic =>", topic)
        # Pass the extracted topic to the refusal chain
        refusal_text = refusal_chain.run({"topic": topic})
        final_refusal = tailor_chain.run({"response": refusal_text})
        return final_refusal.strip()

    # If Wellness => wellness RAG => if insufficient => web => unify => tailor
    if classification == "Wellness":
        rag_result = wellness_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        if not csv_answer:
            web_answer = do_web_search(moderated_query)
        else:
            lower_ans = csv_answer.lower()
            if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
                web_answer = do_web_search(moderated_query)
            else:
                web_answer = ""
        final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
        final_answer = tailor_chain.run({"response": final_merged})
        return final_answer.strip()

    # If Brand => brand RAG => tailor => return
    if classification == "Brand":
        rag_result = brand_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        final_merged = cleaner_chain.merge(kb=csv_answer, web="")
        final_answer = tailor_chain.run({"response": final_merged})
        return final_answer.strip()

    # fallback
    refusal_text = refusal_chain.run({"topic": "this topic"})
    final_refusal = tailor_chain.run({"response": refusal_text})
    return final_refusal.strip()
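
# Minimal usage sketch (assumed entry point; the sample question is illustrative only):
# running the file directly sends one query through the full moderation -> classification
# -> RAG -> tailoring pipeline and prints the result.
if __name__ == "__main__":
    sample_query = "What are some natural ways to improve sleep quality?"
    print(run_with_chain(sample_query))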