File size: 3,742 Bytes
46358a2
e2fac8d
39f6145
6293678
29e0785
6293678
66e2112
 
 
 
 
e3ffcc7
66e2112
 
 
 
 
 
 
 
 
46358a2
66e2112
 
 
 
 
 
 
 
 
e3ffcc7
66e2112
 
 
 
 
 
 
e2fac8d
66e2112
 
 
 
 
 
 
 
e2fac8d
66e2112
 
 
 
 
 
 
 
 
 
 
 
0a17bfe
66e2112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a17bfe
66e2112
 
e2fac8d
66e2112
e2fac8d
66e2112
e2fac8d
 
 
 
ca0aa0f
 
 
 
 
e2fac8d
 
 
29e0785
e2fac8d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import spaces

class LlamaGuardModeration:
    """Content-moderation wrapper around meta-llama/Llama-Guard-3-8B.

    Classifies a single user/assistant exchange as Safe or Unsafe and
    reports the violated category line, for use as a Gradio Space handler.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Best-effort device label; actual placement is decided by
        # device_map="auto" at load time (see initialize_model).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "meta-llama/Llama-Guard-3-8B"
        self.dtype = torch.bfloat16

        # A HuggingFace token is required to download the gated
        # Llama Guard weights; fail fast if it is not configured.
        self.huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
        if not self.huggingface_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

        # Load the model eagerly so the first request is not slow.
        self.initialize_model()

    def initialize_model(self):
        """Load the tokenizer and model once; subsequent calls are no-ops."""
        if self.model is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id,
                token=self.huggingface_token
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                device_map="auto",
                token=self.huggingface_token,
                low_cpu_mem_usage=True
            )

    @staticmethod
    def parse_llama_guard_output(result):
        """Parse raw Llama Guard text into a (status, categories, raw) tuple.

        Args:
            result: Full decoded model output, typically containing the
                prompt followed by "<END CONVERSATION>" and the verdict.

        Returns:
            Tuple of (safety_status, violated_categories, raw_assessment)
            where safety_status is "Safe", "Unsafe", or "Error".
        """
        # Everything after the conversation marker is the model's verdict;
        # if the marker is absent, split()[-1] falls back to the whole text.
        safety_assessment = result.split("<END CONVERSATION>")[-1].strip()
        lines = [line.strip().lower() for line in safety_assessment.split('\n') if line.strip()]

        if not lines:
            return "Error", "No valid output", safety_assessment

        safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)

        if safety_status == 'safe':
            return "Safe", "None", safety_assessment
        elif safety_status == 'unsafe':
            # Llama Guard emits the violated category codes on the line
            # immediately after "unsafe"; fall back when it is missing.
            violated_categories = next(
                (lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)),
                "Unspecified"
            )
            return "Unsafe", violated_categories, safety_assessment
        else:
            return "Error", f"Invalid output: {safety_status}", safety_assessment

    @spaces.GPU
    def moderate(self, user_input, assistant_response):
        """Run Llama Guard moderation over one user/assistant exchange.

        Args:
            user_input: The user's message.
            assistant_response: The assistant's reply to evaluate.

        Returns:
            (safety_status, violated_categories, raw_assessment) from
            parse_llama_guard_output.
        """
        chat = [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": assistant_response},
        ]

        # FIX: with device_map="auto" the model may be sharded/offloaded,
        # so inputs must go to the model's actual input device rather than
        # the module-level "cuda"/"cpu" guess in self.device.
        input_ids = self.tokenizer.apply_chat_template(
            chat,
            return_tensors="pt"
        ).to(self.model.device)

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=200,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=False  # deterministic verdicts for moderation
            )

        result = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return self.parse_llama_guard_output(result)

# Create the moderator instance at import time; this eagerly downloads and
# loads the model, so the Space is ready before the first request arrives.
moderator = LlamaGuardModeration()

# Gradio interface configuration: the three output textboxes map 1:1 to the
# (safety_status, violated_categories, raw_assessment) tuple returned by
# moderator.moderate.
iface = gr.Interface(
    fn=moderator.moderate,
    inputs=[
        gr.Textbox(lines=3, label="User Input"),
        gr.Textbox(lines=3, label="Assistant Response")
    ],
    outputs=[
        gr.Textbox(label="Safety Status"),
        gr.Textbox(label="Violated Categories"),
        gr.Textbox(label="Raw Output")
    ],
    title="Llama Guard Moderation",
    description="Enter a user input and an assistant response to check for content moderation."
)

if __name__ == "__main__":
    iface.launch()