Phoenix21 commited on
Commit
b0739e4
·
verified ·
1 Parent(s): c09fe62

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +26 -25
pipeline.py CHANGED
@@ -4,14 +4,13 @@ import spacy
4
  import pandas as pd
5
  from typing import Optional
6
  import subprocess
7
- import asyncio # Needed for managing async tasks
8
  from langchain.llms.base import LLM
9
  from langchain.docstore.document import Document
10
  from langchain.embeddings import HuggingFaceEmbeddings
11
  from langchain.vectorstores import FAISS
12
  from langchain.chains import RetrievalQA
13
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
14
- from pydantic_ai import Agent # Import Pydantic AI's Agent
15
  from mistralai import Mistral
16
  from langchain.prompts import PromptTemplate
17
 
@@ -26,9 +25,6 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
26
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
27
  client = Mistral(api_key=mistral_api_key)
28
 
29
- # Initialize Pydantic AI Agent (for text validation)
30
- pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
31
-
32
  # Load spaCy model for NER and download it if not already installed
33
  def install_spacy_model():
34
  try:
@@ -67,19 +63,31 @@ def classify_query(query: str) -> str:
67
  classification = class_result.get("text", "").strip()
68
  return classification if classification != "OutOfScope" else "OutOfScope"
69
 
70
- # Function to moderate text using Mistral moderation API (synchronous version)
71
- def moderate_text(query: str) -> str:
 
 
 
 
72
  try:
73
- # Use Pydantic AI to validate the text
74
- pydantic_agent.run_sync(query) # Use sync run for Pydantic validation
75
- except Exception as e:
 
76
  print(f"Error validating text: {e}")
77
  return "Invalid text format."
 
 
 
 
 
 
 
78
 
79
  # Call the Mistral moderation API
80
  response = client.classifiers.moderate_chat(
81
  model="mistral-moderation-latest",
82
- inputs=[{"role": "user", "content": query}]
83
  )
84
 
85
  # Assuming the response is an object of type 'ClassificationResponse',
@@ -93,7 +101,7 @@ def moderate_text(query: str) -> str:
93
  categories.get("selfharm", False):
94
  return "OutOfScope"
95
 
96
- return query
97
 
98
 
99
  # Function to build or load the vector store from CSV data
@@ -147,7 +155,7 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
147
  return rag_chain
148
 
149
  # Function to perform web search using DuckDuckGo
150
- async def do_web_search(query: str) -> str:
151
  search_tool = DuckDuckGoSearchTool()
152
  web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
153
  managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
@@ -158,13 +166,13 @@ async def do_web_search(query: str) -> str:
158
  return response
159
 
160
  # Function to combine web and knowledge base responses
161
- async def merge_responses(kb_answer: str, web_answer: str) -> str:
162
  # Merge both answers with a cohesive response
163
  final_answer = f"Knowledge Base Answer: {kb_answer}\n\nWeb Search Result: {web_answer}"
164
  return final_answer.strip()
165
 
166
  # Orchestrate the entire workflow
167
- async def run_async_pipeline(query: str) -> str:
168
  # Moderate the query for harmful content (sync)
169
  moderated_query = moderate_text(query)
170
  if moderated_query == "OutOfScope":
@@ -183,15 +191,15 @@ async def run_async_pipeline(query: str) -> str:
183
  csv_answer = rag_result["result"].strip()
184
  web_answer = "" # Empty if we found an answer from the knowledge base
185
  if not csv_answer:
186
- web_answer = await do_web_search(moderated_query)
187
- final_merged = await merge_responses(csv_answer, web_answer)
188
  final_answer = tailor_chain.run({"response": final_merged})
189
  return final_answer.strip()
190
 
191
  if classification == "Brand":
192
  rag_result = brand_rag_chain({"query": moderated_query})
193
  csv_answer = rag_result["result"].strip()
194
- final_merged = await merge_responses(csv_answer, "")
195
  final_answer = tailor_chain.run({"response": final_merged})
196
  return final_answer.strip()
197
 
@@ -199,13 +207,6 @@ async def run_async_pipeline(query: str) -> str:
199
  final_refusal = tailor_chain.run({"response": refusal_text})
200
  return final_refusal.strip()
201
 
202
- # Run the pipeline with the event loop
203
- import asyncio
204
-
205
- def run_with_chain(query: str) -> str:
206
- # Use asyncio.run to run the async pipeline, which ensures a fresh event loop
207
- return asyncio.run(run_async_pipeline(query))
208
-
209
  # Initialize chains here
210
  classification_chain = get_classification_chain()
211
  refusal_chain = get_refusal_chain()
 
4
  import pandas as pd
5
  from typing import Optional
6
  import subprocess
 
7
  from langchain.llms.base import LLM
8
  from langchain.docstore.document import Document
9
  from langchain.embeddings import HuggingFaceEmbeddings
10
  from langchain.vectorstores import FAISS
11
  from langchain.chains import RetrievalQA
12
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
13
+ from pydantic import BaseModel, ValidationError # Import Pydantic for text validation
14
  from mistralai import Mistral
15
  from langchain.prompts import PromptTemplate
16
 
 
25
  mistral_api_key = os.environ.get("MISTRAL_API_KEY")
26
  client = Mistral(api_key=mistral_api_key)
27
 
 
 
 
28
  # Load spaCy model for NER and download it if not already installed
29
  def install_spacy_model():
30
  try:
 
63
  classification = class_result.get("text", "").strip()
64
  return classification if classification != "OutOfScope" else "OutOfScope"
65
 
66
+ # Pydantic model for text validation
67
+ class TextInputModel(BaseModel):
68
+ text: str
69
+
70
+ # Function to validate the text input using Pydantic
71
+ def validate_text(query: str) -> str:
72
  try:
73
+ # Attempt to validate the query as a text input
74
+ TextInputModel(text=query)
75
+ return query
76
+ except ValidationError as e:
77
  print(f"Error validating text: {e}")
78
  return "Invalid text format."
79
+
80
+ # Function to moderate text using Mistral moderation API (synchronous version)
81
+ def moderate_text(query: str) -> str:
82
+ # Validate the text using Pydantic
83
+ validated_text = validate_text(query)
84
+ if validated_text == "Invalid text format.":
85
+ return validated_text
86
 
87
  # Call the Mistral moderation API
88
  response = client.classifiers.moderate_chat(
89
  model="mistral-moderation-latest",
90
+ inputs=[{"role": "user", "content": validated_text}]
91
  )
92
 
93
  # Assuming the response is an object of type 'ClassificationResponse',
 
101
  categories.get("selfharm", False):
102
  return "OutOfScope"
103
 
104
+ return validated_text
105
 
106
 
107
  # Function to build or load the vector store from CSV data
 
155
  return rag_chain
156
 
157
  # Function to perform web search using DuckDuckGo
158
+ def do_web_search(query: str) -> str:
159
  search_tool = DuckDuckGoSearchTool()
160
  web_agent = CodeAgent(tools=[search_tool], model=pydantic_agent)
161
  managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
 
166
  return response
167
 
168
  # Function to combine web and knowledge base responses
169
+ def merge_responses(kb_answer: str, web_answer: str) -> str:
170
  # Merge both answers with a cohesive response
171
  final_answer = f"Knowledge Base Answer: {kb_answer}\n\nWeb Search Result: {web_answer}"
172
  return final_answer.strip()
173
 
174
  # Orchestrate the entire workflow
175
+ def run_pipeline(query: str) -> str:
176
  # Moderate the query for harmful content (sync)
177
  moderated_query = moderate_text(query)
178
  if moderated_query == "OutOfScope":
 
191
  csv_answer = rag_result["result"].strip()
192
  web_answer = "" # Empty if we found an answer from the knowledge base
193
  if not csv_answer:
194
+ web_answer = do_web_search(moderated_query)
195
+ final_merged = merge_responses(csv_answer, web_answer)
196
  final_answer = tailor_chain.run({"response": final_merged})
197
  return final_answer.strip()
198
 
199
  if classification == "Brand":
200
  rag_result = brand_rag_chain({"query": moderated_query})
201
  csv_answer = rag_result["result"].strip()
202
+ final_merged = merge_responses(csv_answer, "")
203
  final_answer = tailor_chain.run({"response": final_merged})
204
  return final_answer.strip()
205
 
 
207
  final_refusal = tailor_chain.run({"response": refusal_text})
208
  return final_refusal.strip()
209
 
 
 
 
 
 
 
 
210
  # Initialize chains here
211
  classification_chain = get_classification_chain()
212
  refusal_chain = get_refusal_chain()