import logging
import os
from typing import Optional

import requests
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq

# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

# Vietnamese labels for the Llama Guard 3 hazard categories (S1-S14), used to
# tell the user which policies a flagged question violates.
policy_mapping = {
    "S1": "Tội phạm bạo lực",
    "S2": "Tội phạm không bạo lực",
    "S3": "Tội phạm liên quan đến tình dục",
    "S4": "Khai thác tình dục trẻ em",
    "S5": "Phỉ báng",
    "S6": "Tư vấn chuyên biệt",
    "S7": "Riêng tư",
    "S8": "Sở hữu trí tuệ",
    "S9": "Vũ khí không phân biệt",
    "S10": "Thù ghét",
    "S11": "Tự sát & Tự gây hại",
    "S12": "Nội dung tình dục",
    "S13": "Bầu cử",
    "S14": "Ngôn ngữ xúc phạm và thô tục",
}


class SafetyChecker(BaseModel):
    """Checks whether a user question is safe, using Llama Guard 3 served via Groq."""

    model_name: str = Field(default="llama-guard-3-8b")
    temperature: float = Field(default=0.0)
    api_key: Optional[str] = Field(default=None)
    prompt_safe: Optional[ChatPromptTemplate] = None
    system_prompt: str = Field(default="""Task type
Check if there is unsafe content in the conversation messages according to our safety policy with the below categories.

Policy
<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex-Related Crimes.
S4: Child Sexual Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Suicide & Self-Harm.
S12: Sexual Content.
S13: Elections.
S14: Offensive and Vulgar Language.
<END UNSAFE CONTENT CATEGORIES>

Output format
Provide your safety assessment for the user message below:
The first line must read 'safe' or 'unsafe'.
If 'unsafe', a second line must include a comma-separated list of the violated categories.
""")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, api_key: Optional[str] = None, **data):
        super().__init__(**data)
        load_dotenv()
        # Fall back to the GROQ_API_KEY environment variable when no key is passed in.
        self.api_key = api_key or os.getenv("GROQ_API_KEY")

        # Guard prompt: the Llama Guard instructions as the system message and the
        # user's question as the human message.
        self.prompt_safe = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                ("human", "{question}"),
            ]
        )

    def check_safety(self, question: str) -> str:
        """Return 'safe', or a Vietnamese message naming the violated policy categories."""
        llm = ChatGroq(model_name=self.model_name, temperature=self.temperature, api_key=self.api_key)
        safety_chain = self.prompt_safe | llm
        out = safety_chain.invoke({"question": question})
        result_lines = out.content.strip().split('\n')

        if result_lines[0].strip().lower() == 'unsafe':
            # The second line (when present) lists the violated category codes, e.g. "S1,S10".
            violated_categories = result_lines[1] if len(result_lines) > 1 else ""
            categories = violated_categories.split(',')

            # Map each code to its Vietnamese description, falling back to the raw code.
            category_descriptions = [policy_mapping.get(cat.strip(), cat.strip()) for cat in categories if cat.strip()]
            descriptions_str = ', '.join(category_descriptions)

            # i.e. "The question is not allowed because it violates the community safety policy: ..."
            return f"Câu hỏi không được cho phép vì vi phạm chính sách an toàn cộng đồng: {descriptions_str}"

        return "safe"
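

# Example usage -- a minimal sketch, assuming a valid GROQ_API_KEY is set in the
# environment or in a local .env file; the sample question is purely illustrative.
# An unsafe question would instead return the Vietnamese policy-violation message
# built in check_safety().
if __name__ == "__main__":
    checker = SafetyChecker()
    print(checker.check_safety("What is the capital of France?"))  # expected: "safe"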