File size: 2,907 Bytes
46358a2 e2fac8d 6293678 29e0785 6293678 46358a2 e2fac8d 0a17bfe 3c1404f 0a17bfe 3c1404f 0a17bfe 3c1404f 0a17bfe 3c1404f 0a17bfe 3c1404f 0a17bfe 3c1404f 0a17bfe 3c1404f 0a17bfe 142b81d e369c4b 3c1404f e2fac8d ca0aa0f 3c1404f ca0aa0f 3c1404f ca0aa0f 3c1404f 0a17bfe e2fac8d 83fe2ae e2fac8d ca0aa0f e2fac8d 29e0785 e2fac8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr
import spaces
# Read the Hugging Face access token up front so that a missing token fails
# at import time instead of on the first moderation request.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
if not huggingface_token:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

# Checkpoint and loading options used each time the model is loaded.
model_id = "meta-llama/Llama-Guard-3-8B-INT8"
dtype = torch.bfloat16  # forwarded as `torch_dtype` to from_pretrained
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)
def parse_llama_guard_output(result):
    """Turn Llama Guard's decoded text into (status, categories, raw) fields.

    Only the text after the last "<END CONVERSATION>" marker is inspected; if
    the marker is absent, the whole string is used. Both a fully decoded
    sequence (prompt + completion) and a bare completion are therefore accepted.

    Args:
        result: Decoded model output text.

    Returns:
        A tuple ``(status, categories, raw)`` of strings:
        - ``status`` is "Safe", "Unsafe" or "Error".
        - ``categories`` is "None" for safe content.
        - For unsafe content, ``categories`` is the line following the verdict
          with its original case kept (e.g. "S1,S10"), or "Unspecified" if no
          line follows.
        - For errors, ``categories`` is a short error description.
        - ``raw`` is the stripped assessment text that was parsed.
    """
    # Keep only the part after "<END CONVERSATION>" (whole string if absent).
    safety_assessment = result.split("<END CONVERSATION>")[-1].strip()

    # Non-empty, stripped lines. Case is preserved so category codes are
    # reported as the model wrote them; verdict matching below is
    # case-insensitive instead.
    lines = [line.strip() for line in safety_assessment.split("\n") if line.strip()]
    if not lines:
        return "Error", "No valid output", safety_assessment

    # The verdict is the first line that is exactly "safe" or "unsafe".
    verdict_index = next(
        (i for i, line in enumerate(lines) if line.lower() in ("safe", "unsafe")),
        None,
    )
    if verdict_index is None:
        # Show what the model actually produced rather than a value that
        # would always be None here.
        return "Error", f"Invalid output: {lines[0]}", safety_assessment

    if lines[verdict_index].lower() == "safe":
        return "Safe", "None", safety_assessment

    # The line following "unsafe" lists the violated categories.
    if verdict_index + 1 < len(lines):
        return "Unsafe", lines[verdict_index + 1], safety_assessment
    return "Unsafe", "Unspecified", safety_assessment
@spaces.GPU
def moderate(user_input, assistant_response):
    """Classify a single user/assistant exchange with Llama Guard.

    Args:
        user_input: The user's message.
        assistant_response: The assistant reply to be assessed.

    Returns:
        The ``(status, categories, raw)`` tuple produced by
        ``parse_llama_guard_output`` for the text the model generated.
    """
    # NOTE(review): the tokenizer and model are loaded on every request, which
    # is expensive; consider caching them if the Spaces GPU runtime allows it.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        quantization_config=quantization_config,
        token=huggingface_token,
        low_cpu_mem_usage=True,
    )

    chat = [
        {"role": "user", "content": user_input},
        {"role": "assistant", "content": assistant_response},
    ]
    # device_map="auto" decides where the model lives; put the inputs on that
    # same device instead of guessing "cuda"/"cpu" independently.
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            max_new_tokens=200,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
        )

    # generate() returns the prompt followed by the completion; decode only the
    # newly generated tokens so "Raw Output" is the model's actual assessment.
    prompt_len = input_ids.shape[-1]
    result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return parse_llama_guard_output(result)
# Gradio UI: two multi-line text inputs (user turn, assistant turn) feeding
# `moderate`, and three text outputs, one per element of its returned tuple.
iface = gr.Interface(
    fn=moderate,
    inputs=[
        gr.Textbox(label="User Input", lines=3),
        gr.Textbox(label="Assistant Response", lines=3),
    ],
    outputs=[
        gr.Textbox(label=field_label)
        for field_label in ("Safety Status", "Violated Categories", "Raw Output")
    ],
    title="Llama Guard Moderation",
    description="Enter a user input and an assistant response to check for content moderation.",
)
def _main():
    """Start the Gradio interface defined in this module."""
    iface.launch()


if __name__ == "__main__":
    _main()