# (Hugging Face Spaces page header — "Spaces: Sleeping" — left over from the
# web capture; not part of the program.)
# Standard library
import asyncio  # event-loop management for the async moderation pipeline
import getpass
import os
import subprocess
import sys
from typing import Optional

# Third-party
import pandas as pd
import spacy
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.vectorstores import FAISS
from mistralai import Mistral
from pydantic_ai import Agent  # Pydantic AI's Agent, used for text validation
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
# --- Mistral client and Pydantic AI agent setup -------------------------------
# Read the API key from the environment. NOTE(review): if MISTRAL_API_KEY is
# unset this is None and Mistral requests will fail later at call time —
# confirm the deployment always provides the variable.
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
client = Mistral(api_key=mistral_api_key)

# Initialize Pydantic AI Agent (for text validation); results are coerced to
# `str` (see moderate_text, which awaits pydantic_agent.run).
pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
def install_spacy_model() -> None:
    """Ensure the spaCy 'en_core_web_sm' model is installed.

    Tries to load the model; a missing model raises OSError, in which case
    the model is downloaded via the spaCy CLI and the function returns.

    Raises:
        subprocess.CalledProcessError: if the download subprocess fails
            (``check=True``).
    """
    try:
        spacy.load("en_core_web_sm")
        print("spaCy model 'en_core_web_sm' is already installed.")
    except OSError:
        print("Downloading spaCy model 'en_core_web_sm'...")
        # Use sys.executable instead of a bare "python" so the model is
        # installed into the same interpreter/virtualenv this script runs in
        # ("python" on PATH may be a different environment entirely).
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )
        print("spaCy model 'en_core_web_sm' downloaded successfully.")
# Ensure the NER model is present, then load it once at import time so every
# call to extract_main_topic() reuses the same pipeline object.
install_spacy_model()
nlp = spacy.load("en_core_web_sm")
# Extract the main topic from a query using spaCy NER, with a noun fallback.
def extract_main_topic(query: str) -> str:
    """Return the most salient topic found in *query*.

    Preference order: the first named entity labeled ORG/PRODUCT/PERSON/
    GPE/TIME, then the first NOUN or PROPN token, and finally the literal
    string "this topic" when neither search produces anything.
    """
    doc = nlp(query)
    wanted_labels = {"ORG", "PRODUCT", "PERSON", "GPE", "TIME"}
    topic = next(
        (ent.text for ent in doc.ents if ent.label_ in wanted_labels),
        None,
    )
    if not topic:
        topic = next(
            (tok.text for tok in doc if tok.pos_ in ("NOUN", "PROPN")),
            None,
        )
    return topic if topic else "this topic"
def classify_query(query: str) -> str:
    """Classify *query* into a topic label such as "Wellness" or "OutOfScope".

    A cheap keyword pre-filter short-circuits obvious wellness questions,
    avoiding an LLM round-trip; everything else is delegated to the
    module-level ``classification_chain``.
    """
    wellness_keywords = ("box breathing", "meditation", "yoga", "mindfulness", "breathing exercises")
    lowered = query.lower()  # lower once, not once per keyword
    if any(keyword in lowered for keyword in wellness_keywords):
        return "Wellness"
    class_result = classification_chain.invoke({"query": query})
    # The original trailing ternary (`x if x != "OutOfScope" else "OutOfScope"`)
    # was a no-op, so the chain's stripped label is returned directly.
    return class_result.get("text", "").strip()
# Function to moderate text using Mistral moderation API (async version)
async def moderate_text(query: str) -> str:
    """Screen *query* for harmful content.

    Returns:
        "Invalid text format." if the Pydantic AI validation raises,
        "OutOfScope" if the Mistral moderation endpoint flags a high-risk
        category, otherwise the original *query* unchanged.
    """
    try:
        # Validate the text through the Pydantic AI agent first; any failure
        # is reported and returned as an error string rather than raised.
        await pydantic_agent.run(query)  # Use async run for Pydantic validation
    except Exception as e:
        print(f"Error validating text: {e}")
        return "Invalid text format."
    # NOTE(review): this awaits `moderate_chat` — confirm the installed
    # mistralai SDK version exposes an awaitable here (some versions provide a
    # separate `moderate_chat_async`), and that the response object supports
    # dict-style indexing (`response['results']`) rather than attribute access.
    response = await client.classifiers.moderate_chat(
        model="mistral-moderation-latest",
        inputs=[{"role": "user", "content": query}]
    )
    categories = response['results'][0]['categories']
    # Reject when any of these high-risk moderation categories is flagged.
    # NOTE(review): verify "selfharm" matches the API's exact category key.
    if categories.get("violence_and_threats", False) or \
        categories.get("hate_and_discrimination", False) or \
        categories.get("dangerous_and_criminal_content", False) or \
        categories.get("selfharm", False):
        return "OutOfScope"
    return query
# Use the event loop to run the async functions properly
def _refuse() -> str:
    """Build the standard polite refusal (refusal chain piped through tailor chain)."""
    refusal_text = refusal_chain.run({"topic": "this topic"})
    return tailor_chain.run({"response": refusal_text}).strip()


async def run_async_pipeline(query: str) -> str:
    """Moderate, classify, and answer *query*; return the final response text.

    Flow: async moderation -> classification -> RAG answer for "Wellness"
    (with a web-search fallback when the knowledge base is empty) or "Brand";
    any other label — including "OutOfScope" — yields the tailored refusal.
    """
    # Moderate the query for harmful content (async)
    moderated_query = await moderate_text(query)
    if moderated_query == "OutOfScope":
        return "Sorry, this query contains harmful or inappropriate content."

    classification = classify_query(moderated_query)

    if classification == "Wellness":
        rag_result = wellness_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        # Fall back to a web search only when the knowledge base came up empty.
        web_answer = "" if csv_answer else await do_web_search(moderated_query)
        final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
        return tailor_chain.run({"response": final_merged}).strip()

    if classification == "Brand":
        rag_result = brand_rag_chain({"query": moderated_query})
        csv_answer = rag_result["result"].strip()
        final_merged = cleaner_chain.merge(kb=csv_answer, web="")
        return tailor_chain.run({"response": final_merged}).strip()

    # "OutOfScope" and any unrecognized label both get the refusal response;
    # the original duplicated this refusal code in two branches — now shared.
    return _refuse()
# Run the pipeline with the event loop
def run_with_chain(query: str) -> str:
    """Synchronous entry point: drive the async pipeline to completion and return its answer."""
    pipeline_coro = run_async_pipeline(query)
    return asyncio.run(pipeline_coro)