import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr
import spaces
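
# This Space wraps meta-llama/Llama-Guard-3-8B-INT8 behind a Gradio UI: it
# takes a user/assistant exchange and returns a safety verdict, any violated
# hazard categories, and the raw model output.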

class LlamaGuardModeration:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "meta-llama/Llama-Guard-3-8B-INT8"
        self.dtype = torch.bfloat16
        
        # Fetch the Hugging Face token from the environment
        self.huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
        if not self.huggingface_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
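
        # Note: meta-llama checkpoints are gated on the Hugging Face Hub, so
        # the token must belong to an account that has accepted the Llama
        # Guard 3 license; otherwise the downloads below will fail.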
        
        # Initialize the model and tokenizer
        self.initialize_model()

    def initialize_model(self):
        """モデルとトークナイザーの初期化"""
        if self.model is None:
            # Configure 8-bit loading via bitsandbytes (bnb_4bit_compute_dtype
            # only applies to 4-bit quantization, so it is omitted here)
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True
            )
            
            # Initialize the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id, 
                token=self.huggingface_token
            )
            
            # Initialize the model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                device_map="auto",
                quantization_config=quantization_config,
                token=self.huggingface_token,
                low_cpu_mem_usage=True
            )

    @staticmethod
    def parse_llama_guard_output(result):
        """Llama Guardの出力を解析"""
        safety_assessment = result.split("<END CONVERSATION>")[-1].strip()
        lines = [line.strip().lower() for line in safety_assessment.split('\n') if line.strip()]
        
        if not lines:
            return "Error", "No valid output", safety_assessment

        safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)
        
        if safety_status == 'safe':
            return "Safe", "None", safety_assessment
        elif safety_status == 'unsafe':
            violated_categories = next(
                (lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)), 
                "Unspecified"
            )
            return "Unsafe", violated_categories, safety_assessment
        else:
            return "Error", f"Invalid output: {safety_status}", safety_assessment

    @spaces.GPU
    def moderate(self, user_input, assistant_response):
        """モデレーション実行"""
        chat = [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": assistant_response},
        ]
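
        # apply_chat_template wraps the exchange in Llama Guard's moderation
        # prompt, including the <END CONVERSATION> marker that
        # parse_llama_guard_output later splits on.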
        
        input_ids = self.tokenizer.apply_chat_template(
            chat, 
            return_tensors="pt"
        ).to(self.device)
        
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=200,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=False
            )
        
        result = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return self.parse_llama_guard_output(result)

# Create the moderator instance
moderator = LlamaGuardModeration()
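
# A minimal sketch of calling the moderator directly (outside the UI). The
# example strings are hypothetical; category codes come from the Llama Guard 3
# taxonomy and are lowercased by the parser (e.g. "s1"):
#
#   status, categories, raw = moderator.moderate(
#       "How do I pick a lock?",       # user input
#       "I can't help with that.",     # assistant response
#   )
#   # status -> "Safe" / "Unsafe" / "Error"; categories -> "None" or codes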

# Configure the Gradio interface
iface = gr.Interface(
    fn=moderator.moderate,
    inputs=[
        gr.Textbox(lines=3, label="User Input"),
        gr.Textbox(lines=3, label="Assistant Response")
    ],
    outputs=[
        gr.Textbox(label="Safety Status"),
        gr.Textbox(label="Violated Categories"),
        gr.Textbox(label="Raw Output")
    ],
    title="Llama Guard Moderation",
    description="Enter a user message and an assistant response to check them against the Llama Guard 3 safety taxonomy."
)

if __name__ == "__main__":
    iface.launch()