Phoenix21 commited on
Commit
db87ae8
·
verified ·
1 Parent(s): 21ce388

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +11 -24
pipeline.py CHANGED
@@ -1,7 +1,5 @@
1
  import os
2
  import getpass
3
- from pydantic_ai import Agent # Import the Agent from pydantic_ai
4
- from pydantic_ai.models.mistral import MistralModel # Import the Mistral model
5
  import spacy # Import spaCy for NER functionality
6
  import pandas as pd
7
  from typing import Optional
@@ -13,25 +11,12 @@ from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMMod
13
  import subprocess # Import subprocess to run shell commands
14
  from langchain.llms.base import LLM # Import LLM
15
 
16
- from classification_chain import get_classification_chain
17
- from refusal_chain import get_refusal_chain
18
- from tailor_chain import get_tailor_chain
19
- from cleaner_chain import get_cleaner_chain, CleanerChain
20
-
21
-
22
- # 1) Environment: set up keys if missing
23
- if not os.environ.get("GEMINI_API_KEY"):
24
- os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
25
- if not os.environ.get("GROQ_API_KEY"):
26
- os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
27
- if not os.environ.get("MISTRAL_API_KEY"):
28
- os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter your Mistral API Key: ")
29
-
30
- # Initialize Mistral agent using Pydantic AI
31
 
 
32
  mistral_api_key = os.environ.get("MISTRAL_API_KEY") # Ensure your Mistral API key is set
33
- mistral_model = MistralModel("mistral-moderation-latest", api_key=mistral_api_key) # Use a Mistral model
34
- mistral_agent = Agent(mistral_model)
35
 
36
  # Load spaCy model for NER and download the spaCy model if not already installed
37
  def install_spacy_model():
@@ -49,21 +34,21 @@ install_spacy_model()
49
  # Load the spaCy model globally
50
  nlp = spacy.load("en_core_web_sm")
51
 
52
- # Function to moderate text using Pydantic AI's Mistral moderation model
53
  def moderate_text(query: str) -> str:
54
  """
55
- Classifies the query as harmful or not using Mistral Moderation via Pydantic AI.
56
  Returns "OutOfScope" if harmful, otherwise returns the original query.
57
  """
58
  # Use the moderation API to evaluate if the query is harmful
59
- response = mistral_agent.model.classifiers.moderate_chat(
60
  model="mistral-moderation-latest",
61
  inputs=[
62
  {"role": "user", "content": query},
63
  ],
64
  )
65
-
66
- # Assuming the response contains 'results' with category scores
67
  categories = response['results'][0]['categories']
68
 
69
  # Check if harmful content is flagged in moderation categories
@@ -72,7 +57,9 @@ def moderate_text(query: str) -> str:
72
  categories.get("dangerous_and_criminal_content", False) or \
73
  categories.get("selfharm", False):
74
  return "OutOfScope"
 
75
  return query
 
76
  # 3) build_or_load_vectorstore (no changes)
77
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
78
  if os.path.exists(store_dir):
 
1
  import os
2
  import getpass
 
 
3
  import spacy # Import spaCy for NER functionality
4
  import pandas as pd
5
  from typing import Optional
 
11
  import subprocess # Import subprocess to run shell commands
12
  from langchain.llms.base import LLM # Import LLM
13
 
14
+ # Mistral Client Setup
15
+ from mistralai import Mistral # Import the Mistral client
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Initialize Mistral API client
18
  mistral_api_key = os.environ.get("MISTRAL_API_KEY") # Ensure your Mistral API key is set
19
+ client = Mistral(api_key=mistral_api_key)
 
20
 
21
  # Load spaCy model for NER and download the spaCy model if not already installed
22
  def install_spacy_model():
 
34
  # Load the spaCy model globally
35
  nlp = spacy.load("en_core_web_sm")
36
 
37
+ # Function to moderate text using Mistral moderation API
38
  def moderate_text(query: str) -> str:
39
  """
40
+ Classifies the query as harmful or not using Mistral Moderation via Mistral API.
41
  Returns "OutOfScope" if harmful, otherwise returns the original query.
42
  """
43
  # Use the moderation API to evaluate if the query is harmful
44
+ response = client.classifiers.moderate_chat(
45
  model="mistral-moderation-latest",
46
  inputs=[
47
  {"role": "user", "content": query},
48
  ],
49
  )
50
+
51
+ # Extracting category scores from response
52
  categories = response['results'][0]['categories']
53
 
54
  # Check if harmful content is flagged in moderation categories
 
57
  categories.get("dangerous_and_criminal_content", False) or \
58
  categories.get("selfharm", False):
59
  return "OutOfScope"
60
+
61
  return query
62
+
63
  # 3) build_or_load_vectorstore (no changes)
64
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
65
  if os.path.exists(store_dir):