Update pipeline.py
pipeline.py  CHANGED  (+26 -25)
@@ -4,14 +4,13 @@ import spacy
 import pandas as pd
 from typing import Optional
 import subprocess
-import asyncio  # Needed for managing async tasks
 from langchain.llms.base import LLM
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
-from pydantic_ai import Agent
+from pydantic import BaseModel, ValidationError  # Import Pydantic for text validation
 from mistralai import Mistral
 from langchain.prompts import PromptTemplate
 
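The import swap mirrors the change in validation strategy: the removed line (truncated on this page, but evidently pydantic_ai's Agent, given the pydantic_agent it fed) brought in an LLM-backed validator, while the added line brings in plain Pydantic, whose schema check runs locally with no API call. A minimal sketch using only the newly imported names:

    from pydantic import BaseModel, ValidationError

    class TextInputModel(BaseModel):
        text: str  # the only field the pipeline validates

    TextInputModel(text="hello")  # valid; inputs that cannot be validated as str (e.g. None) raise ValidationError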
@@ -26,9 +25,6 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)
 
-# Initialize Pydantic AI Agent (for text validation)
-pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
-
 # Load spaCy model for NER and download it if not already installed
 def install_spacy_model():
     try:
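The hunk cuts off after the first lines of install_spacy_model; a minimal sketch of how such a helper typically completes, assuming the en_core_web_sm model and the subprocess import from the top of the file (the body below the try: is not shown in this diff):

    import sys

    def install_spacy_model(model_name: str = "en_core_web_sm"):
        try:
            # Check whether the model is already installed
            spacy.load(model_name)
        except OSError:
            # Download it through the spaCy CLI if missing
            subprocess.run([sys.executable, "-m", "spacy", "download", model_name], check=True)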
@@ -67,19 +63,31 @@ def classify_query(query: str) -> str:
     classification = class_result.get("text", "").strip()
     return classification if classification != "OutOfScope" else "OutOfScope"
 
-# …
-def moderate_text(query: str) -> str:
+# Pydantic model for text validation
+class TextInputModel(BaseModel):
+    text: str
+
+# Function to validate the text input using Pydantic
+def validate_text(query: str) -> str:
     try:
-        # …
-        …
-    except … as e:
+        # Attempt to validate the query as a text input
+        TextInputModel(text=query)
+        return query
+    except ValidationError as e:
         print(f"Error validating text: {e}")
         return "Invalid text format."
+
+# Function to moderate text using Mistral moderation API (synchronous version)
+def moderate_text(query: str) -> str:
+    # Validate the text using Pydantic
+    validated_text = validate_text(query)
+    if validated_text == "Invalid text format.":
+        return validated_text
 
     # Call the Mistral moderation API
     response = client.classifiers.moderate_chat(
         model="mistral-moderation-latest",
-        inputs=[{"role": "user", "content": …
+        inputs=[{"role": "user", "content": validated_text}]
     )
 
     # Assuming the response is an object of type 'ClassificationResponse',
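A quick behavioral sketch of the helpers added above (the sample inputs are illustrative):

    validate_text("Tell me about the brand")  # returns the query unchanged
    validate_text(None)  # prints a ValidationError and returns "Invalid text format."

Note that moderate_text treats the literal string "Invalid text format." as a sentinel, so a user who types exactly that sentence would short-circuit the moderation call; raising an exception or returning None would make the contract less ambiguous.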
@@ -93,7 +101,7 @@ def moderate_text(query: str) -> str:
        categories.get("selfharm", False):
         return "OutOfScope"
 
-    return …
+    return validated_text
 
 
 # Function to build or load the vector store from CSV data
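The condition above is shown only partially; a sketch of the full check it belongs to, assuming the moderation response exposes results[0].categories as a dict of booleans (the attribute path and the first category name are assumptions here, only "selfharm" appears in this diff):

    categories = response.results[0].categories
    if categories.get("violence_and_threats", False) or \
       categories.get("selfharm", False):
        return "OutOfScope"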
@@ -147,7 +155,7 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     return rag_chain
 
 # Function to perform web search using DuckDuckGo
-async def do_web_search(query: str) -> str:
+def do_web_search(query: str) -> str:
     search_tool = DuckDuckGoSearchTool()
     web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
     managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
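Two things worth noting in this hunk: the agents are rebuilt on every call to do_web_search, and web_agent is constructed with model=pydantic_agent, a name this same commit deletes, so the function will raise NameError as merged. A working variant would pass a smolagents model instead (a sketch; the model id is illustrative):

    llm_model = LiteLLMModel(model_id="mistral/mistral-large-latest")
    search_tool = DuckDuckGoSearchTool()
    web_agent = CodeAgent(tools=[search_tool], model=llm_model)
    result = web_agent.run("example search query")  # CodeAgent.run executes the task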
@@ -158,13 +166,13 @@ async def do_web_search(query: str) -> str:
     return response
 
 # Function to combine web and knowledge base responses
-…
+def merge_responses(kb_answer: str, web_answer: str) -> str:
     # Merge both answers with a cohesive response
     final_answer = f"Knowledge Base Answer: {kb_answer}\n\nWeb Search Result: {web_answer}"
     return final_answer.strip()
 
 # Orchestrate the entire workflow
-async def run_async_pipeline(query: str) -> str:
+def run_pipeline(query: str) -> str:
     # Moderate the query for harmful content (sync)
     moderated_query = moderate_text(query)
     if moderated_query == "OutOfScope":
@@ -183,15 +191,15 @@ async def run_async_pipeline(query: str) -> str:
         csv_answer = rag_result["result"].strip()
         web_answer = ""  # Empty if we found an answer from the knowledge base
         if not csv_answer:
-            web_answer = …
-        final_merged = …
+            web_answer = do_web_search(moderated_query)
+        final_merged = merge_responses(csv_answer, web_answer)
         final_answer = tailor_chain.run({"response": final_merged})
         return final_answer.strip()
 
     if classification == "Brand":
         rag_result = brand_rag_chain({"query": moderated_query})
         csv_answer = rag_result["result"].strip()
-        final_merged = …
+        final_merged = merge_responses(csv_answer, "")
         final_answer = tailor_chain.run({"response": final_merged})
         return final_answer.strip()
 
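With the branches above in place, the whole flow is now a plain function call (usage sketch; the query is illustrative):

    answer = run_pipeline("What does the brand say about recycled materials?")
    print(answer)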
@@ -199,13 +207,6 @@ async def run_async_pipeline(query: str) -> str:
     final_refusal = tailor_chain.run({"response": refusal_text})
     return final_refusal.strip()
 
-# Run the pipeline with the event loop
-import asyncio
-
-def run_with_chain(query: str) -> str:
-    # Use asyncio.run to run the async pipeline, which ensures a fresh event loop
-    return asyncio.run(run_async_pipeline(query))
-
 # Initialize chains here
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()
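Since run_with_chain and its asyncio.run wrapper are removed along with run_async_pipeline, any caller that still imports run_with_chain (for example the Space's app entry point, which this diff does not show) needs the matching one-line change:

    # before: result = run_with_chain(query)  # internally ran asyncio.run(run_async_pipeline(query))
    result = run_pipeline(query)  # the pipeline is now synchronous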