import os
from typing import Optional

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq

# Load environment variables (e.g. the Groq API key) from a local .env file.
load_dotenv()

# Maps Llama Guard 3 hazard category codes to the Vietnamese descriptions
# shown to end users when a question is rejected.
policy_mapping = {
    "S1": "Tội phạm bạo lực",
    "S2": "Tội phạm không bạo lực",
    "S3": "Tội phạm liên quan đến tình dục",
    "S4": "Khai thác tình dục trẻ em",
    "S5": "Phỉ báng",
    "S6": "Tư vấn chuyên biệt",
    "S7": "Riêng tư",
    "S8": "Sở hữu trí tuệ",
    "S9": "Vũ khí không phân biệt",
    "S10": "Thù ghét",
    "S11": "Tự sát & Tự gây hại",
    "S12": "Nội dung tình dục",
    "S13": "Bầu cử",
    "S14": "Ngôn ngữ xúc phạm và thô tục"
}

class SafetyChecker(BaseModel):
    """Uses Llama Guard 3 (served via Groq) to check user questions against the safety policy."""

    model_name: str = Field(default="llama-guard-3-8b")
    temperature: float = Field(default=0.0)
    api_key: Optional[str] = Field(default=None)
    prompt_safe: Optional[ChatPromptTemplate] = None
    system_prompt: str = Field(default="""Task: Check if there is unsafe content in the messages in conversations according to our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex-Related Crimes.
S4: Child Sexual Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Suicide & Self-Harm.
S12: Sexual Content.
S13: Elections.
S14: Offensive and Vulgar Language.
<END UNSAFE CONTENT CATEGORIES>

Output format:
The first line must read 'safe' or 'unsafe'.
If unsafe, a second line must include a comma-separated list of the violated categories.
""")

    class Config:
        # Allow non-pydantic field types such as ChatPromptTemplate.
        arbitrary_types_allowed = True

    def __init__(self, api_key: Optional[str] = None, **data):
        super().__init__(**data)
        # Fall back to the GROQ_API_KEY environment variable when no key is passed explicitly.
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        # Build the safety-check prompt: the system policy followed by the user's question.
        self.prompt_safe = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                ("human", "{question}"),
            ]
        )

    def check_safety(self, question: str) -> str:
        """Return 'safe', or a Vietnamese rejection message listing the violated policy categories."""
        llm = ChatGroq(model_name=self.model_name, temperature=self.temperature, api_key=self.api_key)
        safety_chain = self.prompt_safe | llm
        out = safety_chain.invoke({"question": question})
        result_lines = out.content.strip().split('\n')

        if result_lines[0].strip().lower() == 'unsafe':
            # The second line, when present, lists the violated category codes, e.g. "S1,S10".
            violated_categories = result_lines[1] if len(result_lines) > 1 else ""

            # Trim whitespace and map each code to its Vietnamese description,
            # falling back to the raw code if it is not in policy_mapping.
            categories = [cat.strip() for cat in violated_categories.split(',') if cat.strip()]
            category_descriptions = [policy_mapping.get(cat, cat) for cat in categories]
            descriptions_str = ', '.join(category_descriptions)

            # "The question is not allowed because it violates the community safety policy: ..."
            return f"Câu hỏi không được cho phép vì vi phạm chính sách an toàn cộng đồng: {descriptions_str}"

        return "safe"