schroneko committed on
Commit 66e2112 · verified · 1 Parent(s): e369c4b

Update app.py

Files changed (1)
  1. app.py +86 -57
app.py CHANGED
@@ -4,70 +4,99 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  import gradio as gr
  import spaces
 
- huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
- if not huggingface_token:
-     raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
-
- model_id = "meta-llama/Llama-Guard-3-8B-INT8"
- dtype = torch.bfloat16
-
- quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
- def parse_llama_guard_output(result):
-     # Extract the part after "<END CONVERSATION>"
-     safety_assessment = result.split("<END CONVERSATION>")[-1].strip()
-
-     # Split into lines and process each one
-     lines = [line.strip().lower() for line in safety_assessment.split('\n') if line.strip()]
-
-     if not lines:
-         return "Error", "No valid output", safety_assessment
-
-     # Look for "safe" or "unsafe"
-     safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)
-
-     if safety_status == 'safe':
-         return "Safe", "None", safety_assessment
-     elif safety_status == 'unsafe':
-         # Treat the line after "unsafe" as the violated categories
-         violated_categories = next((lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)), "Unspecified")
-         return "Unsafe", violated_categories, safety_assessment
-     else:
-         return "Error", f"Invalid output: {safety_status}", safety_assessment
-
- @spaces.GPU
- def moderate(user_input, assistant_response):
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)
-     model = AutoModelForCausalLM.from_pretrained(
-         model_id,
-         torch_dtype=dtype,
-         device_map="auto",
-         quantization_config=quantization_config,
-         token=huggingface_token,
-         low_cpu_mem_usage=True
-     )
-
-     chat = [
-         {"role": "user", "content": user_input},
-         {"role": "assistant", "content": assistant_response},
-     ]
-     input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
-
-     with torch.no_grad():
-         output = model.generate(
-             input_ids=input_ids,
-             max_new_tokens=200,
-             pad_token_id=tokenizer.eos_token_id,
-             do_sample=False
-         )
-
-     result = tokenizer.decode(output[0], skip_special_tokens=True)
-
-     return parse_llama_guard_output(result)
+ class LlamaGuardModeration:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model_id = "meta-llama/Llama-Guard-3-8B-INT8"
+         self.dtype = torch.bfloat16
+
+         # Get the HuggingFace token
+         self.huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
+         if not self.huggingface_token:
+             raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
+
+         # Initialize the model
+         self.initialize_model()
+
+     def initialize_model(self):
+         """Initialize the model and tokenizer"""
+         if self.model is None:
+             # Set up the quantization config
+             quantization_config = BitsAndBytesConfig(
+                 load_in_8bit=True,
+                 bnb_4bit_compute_dtype=self.dtype
+             )
+
+             # Initialize the tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 self.model_id,
+                 token=self.huggingface_token
+             )
+
+             # Initialize the model
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_id,
+                 torch_dtype=self.dtype,
+                 device_map="auto",
+                 quantization_config=quantization_config,
+                 token=self.huggingface_token,
+                 low_cpu_mem_usage=True
+             )
+
+     @staticmethod
+     def parse_llama_guard_output(result):
+         """Parse the Llama Guard output"""
+         safety_assessment = result.split("<END CONVERSATION>")[-1].strip()
+         lines = [line.strip().lower() for line in safety_assessment.split('\n') if line.strip()]
+
+         if not lines:
+             return "Error", "No valid output", safety_assessment
+
+         safety_status = next((line for line in lines if line in ['safe', 'unsafe']), None)
+
+         if safety_status == 'safe':
+             return "Safe", "None", safety_assessment
+         elif safety_status == 'unsafe':
+             violated_categories = next(
+                 (lines[i+1] for i, line in enumerate(lines) if line == 'unsafe' and i+1 < len(lines)),
+                 "Unspecified"
+             )
+             return "Unsafe", violated_categories, safety_assessment
+         else:
+             return "Error", f"Invalid output: {safety_status}", safety_assessment
+
+     @spaces.GPU
+     def moderate(self, user_input, assistant_response):
+         """Run moderation on a user/assistant exchange"""
+         chat = [
+             {"role": "user", "content": user_input},
+             {"role": "assistant", "content": assistant_response},
+         ]
+
+         input_ids = self.tokenizer.apply_chat_template(
+             chat,
+             return_tensors="pt"
+         ).to(self.device)
+
+         with torch.no_grad():
+             output = self.model.generate(
+                 input_ids=input_ids,
+                 max_new_tokens=200,
+                 pad_token_id=self.tokenizer.eos_token_id,
+                 do_sample=False
+             )
+
+         result = self.tokenizer.decode(output[0], skip_special_tokens=True)
+         return self.parse_llama_guard_output(result)
+
+ # Create the moderator instance
+ moderator = LlamaGuardModeration()

+ # Configure the Gradio interface
  iface = gr.Interface(
-     fn=moderate,
+     fn=moderator.moderate,
      inputs=[
          gr.Textbox(lines=3, label="User Input"),
          gr.Textbox(lines=3, label="Assistant Response")
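
For orientation only, not part of this commit: a minimal sketch of how the new module-level moderator could be called outside the Gradio UI, assuming this commit's app.py is importable, HUGGINGFACE_TOKEN is set, GPU/bitsandbytes support is available, and the spaces.GPU decorator degrades gracefully outside a Space. The example strings are purely illustrative.

# Usage sketch (assumptions: app.py from this commit is on the import path,
# HUGGINGFACE_TOKEN is set, and importing app downloads/loads the model).
from app import moderator

# moderate() returns (status, violated_categories, raw_assessment)
status, categories, assessment = moderator.moderate(
    "Tell me how to pick a lock.",     # illustrative user input
    "Sorry, I can't help with that."   # illustrative assistant response
)
print(status, categories)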