File size: 3,742 Bytes
46358a2 e2fac8d 39f6145 6293678 29e0785 6293678 66e2112 2cea941 66e2112 46358a2 66e2112 39f6145 66e2112 e2fac8d 66e2112 e2fac8d 66e2112 0a17bfe 66e2112 0a17bfe 66e2112 e2fac8d 66e2112 e2fac8d 66e2112 e2fac8d ca0aa0f e2fac8d 29e0785 e2fac8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import spaces
class LlamaGuardModeration:
    """Classify a user/assistant exchange with the Llama Guard 3 model.

    The tokenizer and model are loaded once, at construction time, using the
    Hugging Face access token read from the ``HUGGINGFACE_TOKEN`` environment
    variable.
    """

    def __init__(self):
        """Read the access token and load the tokenizer and model.

        Raises:
            ValueError: If ``HUGGINGFACE_TOKEN`` is unset or empty.
        """
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "meta-llama/Llama-Guard-3-8B"
        self.dtype = torch.bfloat16

        # Fetch the Hugging Face token.
        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
        if not self.huggingface_token:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

        # Initialize the model.
        self.initialize_model()

    def initialize_model(self):
        """Load the tokenizer and model; do nothing if a model is already set."""
        if self.model is not None:
            return

        # Initialize the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            token=self.huggingface_token,
        )

        # Initialize the model (without bitsandbytes).
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            device_map="auto",
            token=self.huggingface_token,
            low_cpu_mem_usage=True,
        )

    @staticmethod
    def parse_llama_guard_output(result):
        """Parse Llama Guard output text into a verdict.

        Only the text after the last ``<END CONVERSATION>`` marker is
        inspected; when the marker is absent the whole string is used. The
        verdict is the first non-empty line that equals ``safe`` or ``unsafe``
        (compared case-insensitively).

        Args:
            result: Decoded model text, either the completion alone or the
                prompt followed by the completion.

        Returns:
            A ``(status, categories, assessment)`` tuple of strings:
            ``status`` is ``"Safe"``, ``"Unsafe"`` or ``"Error"``;
            ``categories`` is ``"None"`` when safe, the line following the
            ``unsafe`` verdict (or ``"Unspecified"`` if there is none) when
            unsafe, and an explanatory message on error; ``assessment`` is
            the stripped text that was inspected.
        """
        safety_assessment = result.split("<END CONVERSATION>")[-1].strip()
        # Keep the original casing so category codes are reported as the
        # model emitted them; lower-casing is applied only when matching.
        lines = [
            line.strip()
            for line in safety_assessment.split("\n")
            if line.strip()
        ]
        if not lines:
            return "Error", "No valid output", safety_assessment

        verdict_index = next(
            (
                index
                for index, line in enumerate(lines)
                if line.lower() in ("safe", "unsafe")
            ),
            None,
        )
        if verdict_index is None:
            # Report the text that failed to parse rather than a bare None.
            return "Error", f"Invalid output: {safety_assessment}", safety_assessment

        if lines[verdict_index].lower() == "safe":
            return "Safe", "None", safety_assessment

        # Unsafe: the violated categories, if any, are on the next line.
        if verdict_index + 1 < len(lines):
            violated_categories = lines[verdict_index + 1]
        else:
            violated_categories = "Unspecified"
        return "Unsafe", violated_categories, safety_assessment

    @spaces.GPU
    def moderate(self, user_input, assistant_response):
        """Run moderation on a single user/assistant exchange.

        Args:
            user_input: Text of the user turn.
            assistant_response: Text of the assistant turn.

        Returns:
            The ``(status, categories, assessment)`` tuple produced by
            ``parse_llama_guard_output`` for the model's completion.
        """
        chat = [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": assistant_response},
        ]
        # Send the inputs to wherever the model actually lives; with
        # device_map="auto" that may differ from self.device.
        input_ids = self.tokenizer.apply_chat_template(
            chat,
            return_tensors="pt",
        ).to(self.model.device)
        prompt_length = input_ids.shape[-1]

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                # Single unpadded prompt: every position is a real token.
                # Passing the mask explicitly is needed because the pad id
                # equals the EOS id and the mask cannot be inferred.
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=200,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=False,
            )

        # Decode only the newly generated tokens so the prompt (which echoes
        # the user-supplied text) never reaches the parser.
        result = self.tokenizer.decode(
            output[0][prompt_length:],
            skip_special_tokens=True,
        )
        return self.parse_llama_guard_output(result)
# Create the moderator instance.
moderator = LlamaGuardModeration()

# Configure the Gradio interface: two text inputs (user turn, assistant turn)
# and three text outputs matching the tuple returned by moderator.moderate.
iface = gr.Interface(
    fn=moderator.moderate,
    inputs=[
        gr.Textbox(lines=3, label="User Input"),
        gr.Textbox(lines=3, label="Assistant Response"),
    ],
    outputs=[
        gr.Textbox(label="Safety Status"),
        gr.Textbox(label="Violated Categories"),
        gr.Textbox(label="Raw Output"),
    ],
    title="Llama Guard Moderation",
    description="Enter a user input and an assistant response to check for content moderation.",
)

if __name__ == "__main__":
    iface.launch()