Graduation / pipelines / SafetyChecker / safety_checker.py
import os
from typing import Optional

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq

# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()
# Vietnamese descriptions of the Llama Guard 3 hazard categories (S1-S14),
# used to build the user-facing refusal message.
policy_mapping = {
    "S1": "Tội phạm bạo lực",
    "S2": "Tội phạm không bạo lực",
    "S3": "Tội phạm liên quan đến tình dục",
    "S4": "Khai thác tình dục trẻ em",
    "S5": "Phỉ báng",
    "S6": "Tư vấn chuyên biệt",
    "S7": "Riêng tư",
    "S8": "Sở hữu trí tuệ",
    "S9": "Vũ khí không phân biệt",
    "S10": "Thù ghét",
    "S11": "Tự sát & Tự gây hại",
    "S12": "Nội dung tình dục",
    "S13": "Bầu cử",
    "S14": "Ngôn ngữ xúc phạm và thô tục"
}
class SafetyChecker(BaseModel):
    model_name: str = Field(default="llama-guard-3-8b")
    temperature: float = Field(default=0.0)
    api_key: Optional[str] = Field(default=None)
    prompt_safe: Optional[ChatPromptTemplate] = None
    system_prompt: str = Field(default="""Task
Check if there is unsafe content in the messages of the conversation according to our safety policy with the below categories.
Policy
<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex-Related Crimes.
S4: Child Sexual Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Suicide & Self-Harm.
S12: Sexual Content.
S13: Elections.
S14: Offensive and Vulgar Language.
<END UNSAFE CONTENT CATEGORIES>
Output format
Provide your safety assessment for the user message below:
The first line must read 'safe' or 'unsafe'.
If unsafe, a second line must include a comma-separated list of the violated categories.
""")
    class Config:
        arbitrary_types_allowed = True
    def __init__(self, api_key: Optional[str] = None, **data):
        super().__init__(**data)
        load_dotenv()
        # Prefer an explicitly passed key; otherwise fall back to the
        # GROQ_API_KEY environment variable loaded from .env.
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        # Initialize the ChatPromptTemplate as an instance attribute.
        self.prompt_safe = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                ("human", "{question}"),
            ]
        )
    def check_safety(self, question: str) -> str:
        llm = ChatGroq(model_name=self.model_name, temperature=self.temperature, api_key=self.api_key)
        # Chain the safety prompt into the Llama Guard model.
        safety_chain = self.prompt_safe | llm
        out = safety_chain.invoke({"question": question})
        result_lines = out.content.strip().split('\n')
        if result_lines[0].strip().lower() == 'unsafe':
            # The second line, if present, lists the violated category codes.
            violated_categories = result_lines[1] if len(result_lines) > 1 else ""
            categories = violated_categories.split(',')
            # Trim whitespace and map category codes to their descriptions.
            category_descriptions = [policy_mapping.get(cat.strip(), cat.strip()) for cat in categories]
            # Join the descriptions into a single string.
            descriptions_str = ', '.join(category_descriptions)
            # Vietnamese: "The question is not allowed because it violates the community safety policy: ..."
            return f"Câu hỏi không được cho phép vì vi phạm chính sách an toàn cộng đồng: {descriptions_str}"
        return "safe"