Phoenix21 commited on
Commit
0aef3aa
·
verified ·
1 Parent(s): 53b33ac

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +17 -15
pipeline.py CHANGED
@@ -67,33 +67,35 @@ def classify_query(query: str) -> str:
67
  classification = class_result.get("text", "").strip()
68
  return classification if classification != "OutOfScope" else "OutOfScope"
69
 
70
- # Function to moderate text using Mistral moderation API (sync version)
71
- def moderate_text(query: str) -> str:
72
  try:
73
- # Use Pydantic AI for text validation synchronously
74
- pydantic_agent.run(query) # This is a synchronous call
75
  except Exception as e:
76
  print(f"Error validating text: {e}")
77
  return "Invalid text format."
78
 
79
- # Mistral moderation, no need for await as it's synchronous
80
- response = client.classifiers.moderate_chat(
81
  model="mistral-moderation-latest",
82
  inputs=[{"role": "user", "content": query}]
83
  )
84
 
85
- # Extract moderation categories
86
- categories = response['results'][0]['categories']
87
-
88
- # Check for harmful categories and return "OutOfScope" if any are found
89
- if categories.get("violence_and_threats", False) or \
90
- categories.get("hate_and_discrimination", False) or \
91
- categories.get("dangerous_and_criminal_content", False) or \
92
- categories.get("selfharm", False):
93
- return "OutOfScope"
 
94
 
95
  return query
96
 
 
97
  # Function to build or load the vector store from CSV data
98
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
99
  if os.path.exists(store_dir):
 
67
  classification = class_result.get("text", "").strip()
68
  return classification if classification != "OutOfScope" else "OutOfScope"
69
 
70
+ # Function to moderate text using Mistral moderation API (async version)
71
+ async def moderate_text(query: str) -> str:
72
  try:
73
+ # Use Pydantic AI to validate the text
74
+ await pydantic_agent.run(query) # Use async run for Pydantic validation
75
  except Exception as e:
76
  print(f"Error validating text: {e}")
77
  return "Invalid text format."
78
 
79
+ # Call the Mistral moderation API
80
+ response = await client.classifiers.moderate_chat(
81
  model="mistral-moderation-latest",
82
  inputs=[{"role": "user", "content": query}]
83
  )
84
 
85
+ # Assuming the response is an object of type 'ClassificationResponse',
86
+ # check if it has a 'results' attribute, and then access its categories
87
+ if hasattr(response, 'results') and response.results:
88
+ categories = response.results[0].categories
89
+ # Check if harmful categories are present
90
+ if categories.get("violence_and_threats", False) or \
91
+ categories.get("hate_and_discrimination", False) or \
92
+ categories.get("dangerous_and_criminal_content", False) or \
93
+ categories.get("selfharm", False):
94
+ return "OutOfScope"
95
 
96
  return query
97
 
98
+
99
  # Function to build or load the vector store from CSV data
100
  def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
101
  if os.path.exists(store_dir):