Phoenix21 commited on
Commit
78bd826
·
verified ·
1 Parent(s): 5ad1b40

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +21 -68
pipeline.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
  import getpass
 
 
3
  import spacy # Import spaCy for NER functionality
4
  import pandas as pd
5
  from typing import Optional
@@ -10,33 +12,18 @@ from langchain.chains import RetrievalQA
10
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
11
  import subprocess # Import subprocess to run shell commands
12
  from langchain.llms.base import LLM # Import LLM
13
- from mistralai import Mistral # Import Mistral for moderation
14
-
15
- # Import the chain builders from our separate files
16
- from classification_chain import get_classification_chain
17
- from refusal_chain import get_refusal_chain
18
- from tailor_chain import get_tailor_chain
19
- from cleaner_chain import get_cleaner_chain, CleanerChain
20
-
21
- # 1) Environment: set up keys if missing
22
- if not os.environ.get("GEMINI_API_KEY"):
23
- os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
24
- if not os.environ.get("GROQ_API_KEY"):
25
- os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
26
- if not os.environ.get("MISTRAL_API_KEY"):
27
- os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter your Mistral API Key: ")
28
-
29
- # Initialize Mistral client
30
- mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
31
-
32
- # 2) Load spaCy model for NER and download the spaCy model if not already installed
33
  def install_spacy_model():
34
  try:
35
- # Check if the model is already installed
36
  spacy.load("en_core_web_sm")
37
  print("spaCy model 'en_core_web_sm' is already installed.")
38
  except OSError:
39
- # If model is not installed, download it using subprocess
40
  print("Downloading spaCy model 'en_core_web_sm'...")
41
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
42
  print("spaCy model 'en_core_web_sm' downloaded successfully.")
@@ -47,46 +34,16 @@ install_spacy_model()
47
  # Load the spaCy model globally
48
  nlp = spacy.load("en_core_web_sm")
49
 
50
- # Function to extract the main topic using NER
51
- def extract_main_topic(query: str) -> str:
52
- """
53
- Extracts the main topic from the user's query using spaCy's NER.
54
- Returns the first named entity or noun found in the query.
55
- """
56
- doc = nlp(query) # Use the globally loaded spaCy model
57
-
58
- # Try to extract the main topic as a named entity (person, product, etc.)
59
- main_topic = None
60
- for ent in doc.ents:
61
- # Filter for specific entity types (you can adjust this based on your needs)
62
- if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]: # Add more entity labels as needed
63
- main_topic = ent.text
64
- break
65
-
66
- # If no named entity found, fallback to extracting the first noun or proper noun
67
- if not main_topic:
68
- for token in doc:
69
- if token.pos_ in ["NOUN", "PROPN"]: # Extract first noun or proper noun
70
- main_topic = token.text
71
- break
72
-
73
- # Return the extracted topic or a fallback value if no topic is found
74
- return main_topic if main_topic else "this topic"
75
-
76
- # 3) Function to moderate text using Mistral moderation API
77
  def moderate_text(query: str) -> str:
78
  """
79
- Classifies the query as harmful or not using Mistral Moderation API.
80
  Returns "OutOfScope" if harmful, otherwise returns the original query.
81
  """
82
- response = mistral_client.classifiers.moderate(
83
- model="mistral-moderation-latest",
84
- inputs=[query]
85
- )
86
-
87
- categories = response.results[0].categories
88
-
89
- # Check if any harmful category is flagged
90
  if categories.get("violence_and_threats", False) or \
91
  categories.get("hate_and_discrimination", False) or \
92
  categories.get("dangerous_and_criminal_content", False) or \
@@ -94,7 +51,7 @@ def moderate_text(query: str) -> str:
94
  return "OutOfScope"
95
  return query
96
 
97
- # 4) build_or_load_vectorstore (no changes)
98
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
99
  if os.path.exists(store_dir):
100
  print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
@@ -123,7 +80,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
123
  vectorstore.save_local(store_dir)
124
  return vectorstore
125
 
126
- # 5) Build RAG chain for Gemini (no changes)
127
  def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
128
  class GeminiLangChainLLM(LLM):
129
  def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
@@ -144,13 +101,13 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
144
  )
145
  return rag_chain
146
 
147
- # 6) Initialize all the separate chains
148
  classification_chain = get_classification_chain()
149
  refusal_chain = get_refusal_chain() # Refusal chain will now use dynamic topic
150
  tailor_chain = get_tailor_chain()
151
  cleaner_chain = get_cleaner_chain()
152
 
153
- # 7) Build our vectorstores + RAG chains
154
  wellness_csv = "AIChatbot.csv"
155
  brand_csv = "BrandAI.csv"
156
  wellness_store_dir = "faiss_wellness_store"
@@ -163,7 +120,7 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("
163
  wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
164
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
165
 
166
- # 8) Tools / Agents for web search (no changes)
167
  search_tool = DuckDuckGoSearchTool()
168
  web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
169
  managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
@@ -175,14 +132,10 @@ def do_web_search(query: str) -> str:
175
  response = manager_agent.run(search_query)
176
  return response
177
 
178
- # 9) Orchestrator: run_with_chain
179
  def run_with_chain(query: str) -> str:
180
  print("DEBUG: Starting run_with_chain...")
181
 
182
-
183
- # Ensure the query is a string
184
- query = str(query).strip()
185
-
186
  # 1) Moderate the query for harmful content
187
  moderated_query = moderate_text(query)
188
  if moderated_query == "OutOfScope":
 
1
  import os
2
  import getpass
3
+ from pydantic_ai import Agent # Import the Agent from pydantic_ai
4
+ from pydantic_ai.models.mistral import MistralModel # Import the Mistral model
5
  import spacy # Import spaCy for NER functionality
6
  import pandas as pd
7
  from typing import Optional
 
12
  from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
13
  import subprocess # Import subprocess to run shell commands
14
  from langchain.llms.base import LLM # Import LLM
15
+
16
+ # Initialize Mistral agent using Pydantic AI
17
+ mistral_api_key = os.environ.get("MISTRAL_API_KEY") # Ensure your Mistral API key is set
18
+ mistral_model = MistralModel("mistral-large-latest", api_key=mistral_api_key) # Use a Mistral model
19
+ mistral_agent = Agent(mistral_model)
20
+
21
+ # Load spaCy model for NER and download the spaCy model if not already installed
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def install_spacy_model():
23
  try:
 
24
  spacy.load("en_core_web_sm")
25
  print("spaCy model 'en_core_web_sm' is already installed.")
26
  except OSError:
 
27
  print("Downloading spaCy model 'en_core_web_sm'...")
28
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
29
  print("spaCy model 'en_core_web_sm' downloaded successfully.")
 
34
  # Load the spaCy model globally
35
  nlp = spacy.load("en_core_web_sm")
36
 
37
+ # Function to moderate text using Pydantic AI's Mistral moderation model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def moderate_text(query: str) -> str:
39
  """
40
+ Classifies the query as harmful or not using Mistral Moderation via Pydantic AI.
41
  Returns "OutOfScope" if harmful, otherwise returns the original query.
42
  """
43
+ response = mistral_agent.call("classify", {"inputs": [query]})
44
+ categories = response['results'][0]['categories']
45
+
46
+ # Check if harmful content is flagged in moderation categories
 
 
 
 
47
  if categories.get("violence_and_threats", False) or \
48
  categories.get("hate_and_discrimination", False) or \
49
  categories.get("dangerous_and_criminal_content", False) or \
 
51
  return "OutOfScope"
52
  return query
53
 
54
+ # 3) build_or_load_vectorstore (no changes)
55
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
56
  if os.path.exists(store_dir):
57
  print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
 
80
  vectorstore.save_local(store_dir)
81
  return vectorstore
82
 
83
+ # 4) Build RAG chain for Gemini (no changes)
84
  def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
85
  class GeminiLangChainLLM(LLM):
86
  def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
 
101
  )
102
  return rag_chain
103
 
104
+ # 5) Initialize all the separate chains
105
  classification_chain = get_classification_chain()
106
  refusal_chain = get_refusal_chain() # Refusal chain will now use dynamic topic
107
  tailor_chain = get_tailor_chain()
108
  cleaner_chain = get_cleaner_chain()
109
 
110
+ # 6) Build our vectorstores + RAG chains
111
  wellness_csv = "AIChatbot.csv"
112
  brand_csv = "BrandAI.csv"
113
  wellness_store_dir = "faiss_wellness_store"
 
120
  wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
121
  brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
122
 
123
+ # 7) Tools / Agents for web search (no changes)
124
  search_tool = DuckDuckGoSearchTool()
125
  web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
126
  managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
 
132
  response = manager_agent.run(search_query)
133
  return response
134
 
135
+ # 8) Orchestrator: run_with_chain
136
  def run_with_chain(query: str) -> str:
137
  print("DEBUG: Starting run_with_chain...")
138
 
 
 
 
 
139
  # 1) Moderate the query for harmful content
140
  moderated_query = moderate_text(query)
141
  if moderated_query == "OutOfScope":