Spaces:
Running
Running
import os | |
import getpass | |
import spacy | |
import pandas as pd | |
from typing import Optional, List, Dict, Any | |
import subprocess | |
from langchain.llms.base import LLM | |
from langchain.docstore.document import Document | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.vectorstores import FAISS | |
from langchain.chains import RetrievalQA | |
from smolagents import DuckDuckGoSearchTool, ManagedAgent | |
from pydantic import BaseModel, Field, ValidationError, validator | |
from mistralai import Mistral | |
# Import Google Gemini model | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from classification_chain import get_classification_chain | |
from cleaner_chain import get_cleaner_chain | |
from refusal_chain import get_refusal_chain | |
from tailor_chain import get_tailor_chain | |
from prompts import classification_prompt, refusal_prompt, tailor_prompt | |
# Initialize the Mistral API client (moderate_text uses it for content moderation).
# The key is read from the MISTRAL_API_KEY environment variable; os.environ.get
# yields None when the variable is unset.
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
client = Mistral(api_key=mistral_api_key)
# Set up ChatGoogleGenerativeAI for Gemini.
# Ensure GEMINI_API_KEY is set in your environment variables; its value is
# passed to the model below as google_api_key.
gemini_llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    temperature=0.5,
    max_retries=2,
    google_api_key=os.environ.get("GEMINI_API_KEY"),
    # Additional parameters or safety_settings can be added here if needed
)
# Initialize ManagedAgent for web search using Gemini | |
# pydantic_agent = ManagedAgent( | |
# llm=ChatGoogleGenerativeAI( | |
# model="gemini-1.5-pro", | |
# temperature=0.5, | |
# max_retries=2, | |
# google_api_key=os.environ.get("GEMINI_API_KEY"), | |
# ), | |
# tools=[DuckDuckGoSearchTool()] | |
# ) | |
class QueryInput(BaseModel):
    """Validated wrapper around a single user query string.

    The ``query`` field must be a non-empty string. The validator below
    additionally rejects whitespace-only values and stores the stripped text,
    so ``QueryInput(query="  hi ").query`` is ``"hi"``.
    """

    query: str = Field(..., min_length=1, description="The input query string")

    # Without this decorator the method is just an unused function on the
    # class and none of its checks run during model construction.
    @validator("query")
    def check_query_is_string(cls, v):
        """Reject non-string or blank queries and return the stripped value.

        Raises:
            ValueError: if ``v`` is not a string or contains only whitespace
                (pydantic reports this as a ``ValidationError``).
        """
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string")
        stripped = v.strip()
        if stripped == "":
            raise ValueError("Query cannot be empty or just whitespace")
        return stripped
class ModerationResult(BaseModel):
    """Structured result of a moderation check on one piece of text."""

    # Required: overall safety verdict.
    is_safe: bool = Field(
        ...,
        description="Whether the content is safe",
    )
    # Optional: per-category flags; a fresh empty dict is used when omitted.
    categories: Dict[str, bool] = Field(
        default_factory=dict,
        description="Detected content categories",
    )
    # Required: the text that was moderated.
    original_text: str = Field(
        ...,
        description="The original input text",
    )
def install_spacy_model():
    """Ensure the spaCy model ``en_core_web_sm`` is installed.

    Tries to load the model; if spaCy raises ``OSError`` the model is
    downloaded with ``<current interpreter> -m spacy download en_core_web_sm``.

    Raises:
        subprocess.CalledProcessError: if the download command exits with a
            non-zero status (``check=True``).
    """
    # Local import keeps this block self-contained; the file's import section
    # does not bring in ``sys``.
    import sys

    try:
        spacy.load("en_core_web_sm")
        print("spaCy model 'en_core_web_sm' is already installed.")
    except OSError:
        print("Downloading spaCy model 'en_core_web_sm'...")
        # Use the interpreter that is running this module rather than whatever
        # "python" resolves to on PATH, so the model is installed into the
        # same environment that will later load it.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )
        print("spaCy model 'en_core_web_sm' downloaded successfully.")
# Make sure the model is present, then load the module-level pipeline that
# extract_main_topic uses for entity and part-of-speech lookups.
install_spacy_model()
nlp = spacy.load("en_core_web_sm")
def sanitize_message(message: Any) -> str:
    """Reduce a message of several possible shapes to a stripped string.

    The payload is chosen in this order:
      1. ``message.content`` if the object has a ``content`` attribute;
      2. ``message['content']`` if it is a dict with that key;
      3. for a non-empty list, the ``content`` key or attribute of its first
         element (dict key checked before attribute);
      4. otherwise the message itself.

    Returns:
        ``str(payload).strip()``.

    Raises:
        RuntimeError: if anything goes wrong while extracting or converting.
    """
    try:
        if hasattr(message, 'content'):
            payload = message.content
        elif isinstance(message, dict) and 'content' in message:
            payload = message['content']
        else:
            # Default: stringify the whole message (also covers empty lists and
            # lists whose first element carries no content).
            payload = message
            if isinstance(message, list) and len(message) > 0:
                head = message[0]
                if isinstance(head, dict) and 'content' in head:
                    payload = head['content']
                elif hasattr(head, 'content'):
                    payload = head.content
        return str(payload).strip()
    except Exception as e:
        raise RuntimeError(f"Error in sanitize function: {str(e)}")
def extract_main_topic(query: str) -> str:
    """Pick a short topic phrase out of ``query`` using the spaCy pipeline.

    The first named entity labelled ORG, PRODUCT, PERSON, GPE or TIME wins;
    failing that, the first token tagged NOUN or PROPN is used.

    Returns:
        The matched text, or ``"this topic"`` when nothing matches or when any
        error (including input validation) occurs; errors are printed, not raised.
    """
    entity_labels = ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]
    noun_tags = ["NOUN", "PROPN"]
    try:
        validated = QueryInput(query=query)
        parsed = nlp(validated.query)
        topic = next(
            (entity.text for entity in parsed.ents if entity.label_ in entity_labels),
            None,
        )
        if not topic:
            # No usable entity: fall back to the first noun-like token.
            topic = next(
                (tok.text for tok in parsed if tok.pos_ in noun_tags),
                None,
            )
        return topic or "this topic"
    except Exception as e:
        print(f"Error extracting main topic: {e}")
        return "this topic"
def moderate_text(query: str) -> ModerationResult:
    """Run ``query`` through the Mistral moderation classifier.

    Only four categories from the first result are inspected (violence, hate,
    dangerous content, self-harm); the text is considered safe when none of
    them is flagged. If the response has no results, the text is treated as
    safe with an empty category map.

    Raises:
        ValueError: if ``query`` fails ``QueryInput`` validation.
        RuntimeError: for any other failure, including errors from the API call.
    """
    # (key in our result, key in the API's category mapping)
    category_keys = (
        ("violence", "violence_and_threats"),
        ("hate", "hate_and_discrimination"),
        ("dangerous", "dangerous_and_criminal_content"),
        ("selfharm", "selfharm"),
    )
    try:
        validated = QueryInput(query=query)
        response = client.classifiers.moderate_chat(
            model="mistral-moderation-latest",
            inputs=[{"role": "user", "content": validated.query}],
        )
        flags = {}
        safe = True
        if hasattr(response, 'results') and response.results:
            first_result = response.results[0]
            flags = {
                label: first_result.categories.get(api_key, False)
                for label, api_key in category_keys
            }
            safe = not any(flags.values())
        return ModerationResult(
            is_safe=safe,
            categories=flags,
            original_text=validated.query,
        )
    except ValidationError as e:
        raise ValueError(f"Input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Moderation failed: {str(e)}")
def classify_query(query: str) -> str:
    """Classify ``query`` into a routing label.

    A small keyword list short-circuits to ``"Wellness"`` (case-insensitive
    substring match). Otherwise the module-level ``classification_chain`` is
    invoked and the stripped ``"text"`` entry of its result is returned, with
    ``"OutOfScope"`` substituted when that text is empty or missing.

    Raises:
        ValueError: if ``query`` fails ``QueryInput`` validation.
        RuntimeError: for any other failure.
    """
    try:
        text = QueryInput(query=query).query
        lowered = text.lower()
        for keyword in ("box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"):
            if keyword in lowered:
                return "Wellness"
        chain_output = classification_chain.invoke({"query": text})
        label = chain_output.get("text", "").strip()
        if label == "":
            return "OutOfScope"
        return label
    except ValidationError as e:
        raise ValueError(f"Classification input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Classification failed: {str(e)}")
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    """Load the FAISS store at ``store_dir`` or build it from a Q&A CSV.

    If ``store_dir`` exists it is loaded as-is and ``csv_path`` is not read.
    Otherwise the CSV is read, one document per row is created (answer as
    page content, question as metadata), the documents are embedded, and the
    resulting store is saved to ``store_dir``.

    Args:
        csv_path: CSV with a ``Question`` column and an ``Answers`` (or
            ``Answer``) column; header whitespace is stripped and ``Unnamed``
            columns are dropped.
        store_dir: Directory used to persist / reload the FAISS index.

    Returns:
        The loaded or newly built vector store.

    Raises:
        RuntimeError: for any failure, including a CSV without the required
            columns; the original exception is attached as ``__cause__``.
    """
    # Single source of truth so load and build always use the same embedding model.
    embedding_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    try:
        if os.path.exists(store_dir):
            print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
            embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
            return FAISS.load_local(store_dir, embeddings)

        print(f"DEBUG: Building new store from CSV: {csv_path}")
        df = pd.read_csv(csv_path)
        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
        # Stripping here already normalizes headers such as "Question ", so no
        # separate rename for trailing-space variants is needed.
        df.columns = df.columns.str.strip()
        # Only rename when it cannot collide with an existing "Answers" column;
        # duplicate labels would make row lookups return a Series per cell.
        if "Answer" in df.columns and "Answers" not in df.columns:
            df = df.rename(columns={"Answer": "Answers"})
        if "Question" not in df.columns or "Answers" not in df.columns:
            raise ValueError("CSV must have 'Question' and 'Answers' columns.")

        # NOTE(review): missing cells are stringified (NaN becomes "nan"), as
        # before; confirm whether such rows should be skipped instead.
        docs = [
            Document(page_content=str(answer), metadata={"question": str(question)})
            for question, answer in zip(df["Question"], df["Answers"])
        ]
        # Embeddings are created only after the CSV has been validated.
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        vectorstore = FAISS.from_documents(docs, embedding=embeddings)
        vectorstore.save_local(store_dir)
        return vectorstore
    except Exception as e:
        raise RuntimeError(f"Error building/loading vector store: {str(e)}") from e
def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
    """Create a RetrievalQA chain over ``vectorstore`` backed by ``gemini_llm``.

    The retriever does a similarity search returning the top 3 documents, the
    chain type is "stuff", and source documents are included in the output.

    Raises:
        RuntimeError: if the retriever or chain cannot be constructed.
    """
    try:
        top_k_retriever = vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 3},
        )
        chain_options = {
            "llm": gemini_llm,  # the shared ChatGoogleGenerativeAI instance
            "chain_type": "stuff",
            "retriever": top_k_retriever,
            "return_source_documents": True,
        }
        return RetrievalQA.from_chain_type(**chain_options)
    except Exception as e:
        raise RuntimeError(f"Error building RAG chain: {str(e)}")
def do_web_search(query: str) -> str:
    """Best-effort web search for ``query`` via a DuckDuckGo-equipped agent.

    Returns:
        The agent's result converted to a stripped string, or ``""`` if
        anything fails (the error is printed, never raised).
    """
    try:
        # NOTE(review): confirm the installed smolagents ManagedAgent accepts
        # `llm=` / `tools=`; if construction raises, the handler below turns
        # that into an empty result.
        tools = [DuckDuckGoSearchTool()]
        agent = ManagedAgent(llm=gemini_llm, tools=tools)
        prompt = f"Search for information about: {query}"
        return str(agent.run(prompt)).strip()
    except Exception as e:
        print(f"Web search failed: {e}")
        return ""
def merge_responses(csv_answer: str, web_answer: str) -> str:
    """Combine the knowledge-base answer with the web-search answer.

    Returns:
        Both answers joined under an "Additional information from web search"
        heading when both are non-empty; whichever one is non-empty when only
        one is; an apology message when neither is.
    """
    try:
        if csv_answer and web_answer:
            return f"{csv_answer}\n\nAdditional information from web search:\n{web_answer}"
        if csv_answer:
            return csv_answer
        if web_answer:
            return web_answer
        return "I apologize, but I couldn't find any relevant information."
    except Exception as e:
        print(f"Error merging responses: {e}")
        return csv_answer or web_answer or "I apologize, but I couldn't process the information properly."
def _rag_answer_text(rag_chain, question: str) -> str:
    """Call a RAG chain with ``question`` and return its answer as stripped text."""
    outcome = rag_chain({"query": question})
    if isinstance(outcome, dict) and "result" in outcome:
        return str(outcome["result"]).strip()
    return str(outcome).strip()


def _tailor(text: str) -> str:
    """Pass ``text`` through the module-level tailor chain and strip the output."""
    return tailor_chain.run({"response": text}).strip()


def run_pipeline(query: str) -> str:
    """Answer ``query`` end to end: sanitize, moderate, classify, retrieve, tailor.

    Flow:
      * the query is printed, sanitized and validated;
      * a topic is extracted (used only for refusals);
      * unsafe content returns a fixed apology string;
      * "Wellness" queries use the wellness RAG chain, with a web search only
        when that chain's answer is empty;
      * "Brand" queries use the brand RAG chain;
      * "OutOfScope" and any other label produce a tailored refusal.

    Raises:
        ValueError: if a pydantic ``ValidationError`` escapes from the steps above.
        RuntimeError: for any other failure.
    """
    try:
        print(query)
        validated = QueryInput(query=sanitize_message(query))
        topic = extract_main_topic(validated.query)
        verdict = moderate_text(validated.query)
        if not verdict.is_safe:
            return "Sorry, this query contains harmful or inappropriate content."

        safe_text = verdict.original_text
        label = classify_query(safe_text)

        if label == "Wellness":
            csv_answer = _rag_answer_text(wellness_rag_chain, safe_text)
            # Web search is only a fallback for an empty knowledge-base answer.
            web_answer = "" if csv_answer else do_web_search(safe_text)
            return _tailor(merge_responses(csv_answer, web_answer))

        if label == "Brand":
            csv_answer = _rag_answer_text(brand_rag_chain, safe_text)
            return _tailor(merge_responses(csv_answer, ""))

        # "OutOfScope" and every unrecognized label share the refusal path.
        return _tailor(refusal_chain.run({"topic": topic}))
    except ValidationError as e:
        raise ValueError(f"Input validation failed: {str(e)}")
    except Exception as e:
        raise RuntimeError(f"Error in run_pipeline: {str(e)}")
def run_with_chain(query: str) -> str:
    """Non-raising entry point around ``run_pipeline``.

    Returns:
        The pipeline's answer, or a generic apology string if the pipeline
        raises; the error is printed rather than propagated.
    """
    try:
        answer = run_pipeline(query)
    except Exception as e:
        print(f"Error in run_with_chain: {str(e)}")
        return "I apologize, but I encountered an error processing your request. Please try again."
    return answer
# Initialize chains and vectorstores.
# These module-level names are looked up by classify_query and run_pipeline at
# call time, so they must be assigned before the first query is processed.
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()
# Source CSVs and the directories where their FAISS indexes are persisted.
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"
# Each store is loaded from its directory if that exists, otherwise built from the CSV.
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
# One retrieval chain per knowledge base.
wellness_rag_chain = build_rag_chain(wellness_vectorstore)
brand_rag_chain = build_rag_chain(brand_vectorstore)
print("Pipeline initialized successfully!")