import os
from typing import Optional

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq

load_dotenv()

# Maps Llama Guard category codes to the Vietnamese descriptions shown to users.
policy_mapping = {
    "S1": "Tội phạm bạo lực",
    "S2": "Tội phạm không bạo lực",
    "S3": "Tội phạm liên quan đến tình dục",
    "S4": "Khai thác tình dục trẻ em",
    "S5": "Phỉ báng",
    "S6": "Tư vấn chuyên biệt",
    "S7": "Riêng tư",
    "S8": "Sở hữu trí tuệ",
    "S9": "Vũ khí không phân biệt",
    "S10": "Thù ghét",
    "S11": "Tự sát & Tự gây hại",
    "S12": "Nội dung tình dục",
    "S13": "Bầu cử",
    "S14": "Ngôn ngữ xúc phạm và thô tục",
}


class SafetyChecker(BaseModel):
    """Classifies a user question with Llama Guard 3 (via Groq) against the safety policy."""

    model_name: str = Field(default="llama-guard-3-8b")
    temperature: float = Field(default=0.0)
    api_key: Optional[str] = Field(default=None)
    prompt_safe: Optional[ChatPromptTemplate] = None
    system_prompt: str = Field(default="""Task type
Check if there is unsafe content in the messages in conversations according to our safety policy with the below categories.

Policy
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex-Related Crimes.
S4: Child Sexual Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Suicide & Self-Harm.
S12: Sexual Content.
S13: Elections.
S14: Offensive and Vulgar Language.

Output format
Provide your safety assessment for the user message below:
First line must read 'safe' or 'unsafe'.
If unsafe, a second line must include a comma-separated list of violated categories.
""")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, api_key: Optional[str] = None, **data):
        super().__init__(**data)
        load_dotenv()
        # Fall back to the GROQ_API_KEY environment variable when no key is passed in.
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        # Initialize the ChatPromptTemplate as an instance attribute.
        self.prompt_safe = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                ("human", "{question}"),
            ]
        )

    def check_safety(self, question: str) -> str:
        llm = ChatGroq(
            model_name=self.model_name,
            temperature=self.temperature,
            api_key=self.api_key,
        )
        safety_chain = self.prompt_safe | llm
        out = safety_chain.invoke({"question": question})

        result_lines = out.content.strip().split("\n")
        if result_lines[0] == "unsafe":
            # Retrieve the violated category codes, e.g. "S1,S10"; guard against a missing second line.
            violated_categories = result_lines[1] if len(result_lines) > 1 else ""
            categories = violated_categories.split(",")
            # Trim whitespace and look up the Vietnamese descriptions.
            category_descriptions = [
                policy_mapping.get(cat.strip(), cat.strip()) for cat in categories if cat.strip()
            ]
            # Join the descriptions into a single string.
            descriptions_str = ", ".join(category_descriptions)
            return f"Câu hỏi không được cho phép vì vi phạm chính sách an toàn cộng đồng: {descriptions_str}"
        return "safe"
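

# Minimal usage sketch: assumes a valid GROQ_API_KEY is available in the environment
# (e.g. via a .env file); the sample question is purely illustrative.
if __name__ == "__main__":
    checker = SafetyChecker()
    print(checker.check_safety("Hướng dẫn tôi cách học lập trình Python."))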