tori29umai committed on
Commit
aa2fbfe
1 Parent(s): 607a0f5
Files changed (1) hide show
  1. app.py +19 -61
app.py CHANGED
@@ -31,7 +31,6 @@ def preprocess_image(image):
31
  image = image.astype(np.float32)
32
  return image
33
 
34
-
35
  class webui:
36
  def __init__(self):
37
  self.demo = gr.Blocks()
@@ -41,39 +40,36 @@ class webui:
41
  print("Hugging Faceからモデルをダウンロード中")
42
  onnx_path = hf_hub_download(model_id, "model.onnx")
43
  csv_path = hf_hub_download(model_id, "selected_tags.csv")
44
- ort_sess = ort.InferenceSession(onnx_path)
45
-
46
- print("ONNXモデルを実行中")
47
- print(f"ONNXモデルのパス: {onnx_path}")
48
 
 
49
  image = Image.open(image_path)
50
  image = image.convert("RGB") if image.mode != "RGB" else image
51
  image = preprocess_image(image)
 
 
 
 
52
 
53
  with open(csv_path, "r", encoding="utf-8") as f:
54
  reader = csv.reader(f)
55
- header = next(reader)
56
  rows = list(reader)
 
57
  rating_tags = [row[1] for row in rows if row[2] == "9"]
58
  character_tags = [row[1] for row in rows if row[2] == "4"]
59
  general_tags = [row[1] for row in rows if row[2] == "0"]
60
 
 
 
 
61
 
62
- img = np.array([image])
63
- prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0] # ONNXモデルからの出力
64
-
65
  thresh = 0.35
66
-
67
  # NSFW/SFW判定
68
  tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
69
  max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
70
  max_sfw_score = tag_confidences.get("general", 0)
71
- NSFW_flag = None
72
-
73
- if max_nsfw_score > max_sfw_score:
74
- NSFW_flag = "NSFWの可能性が高いです"
75
- else:
76
- NSFW_flag = "SFWの可能性が高いです"
77
 
78
  # 版権キャラクターの可能性を評価
79
  character_tags_with_probs = []
@@ -85,50 +81,12 @@ class webui:
85
  prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
86
  character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
87
 
88
- IP_flag = None
89
- if character_tags_with_probs:
90
- IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
91
- else:
92
- IP_flag = "版権キャラクターの可能性が低いと思われます"
93
 
94
  # タグを生成
95
- tag_freq = {}
96
- undesired_tags = []
97
- combined_tags = []
98
- general_tag_text = ""
99
- character_tag_text = ""
100
- remove_underscore = True
101
- caption_separator = ", "
102
- general_threshold = 0.35
103
- character_threshold = 0.35
104
-
105
- for i, p in enumerate(prob[4:]):
106
- if i < len(general_tags) and p >= general_threshold:
107
- tag_name = general_tags[i]
108
- if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
109
- tag_name = tag_name.replace("_", " ")
110
-
111
- if tag_name not in undesired_tags:
112
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
113
- general_tag_text += caption_separator + tag_name
114
- combined_tags.append(tag_name)
115
- elif i >= len(general_tags) and p >= character_threshold:
116
- tag_name = character_tags[i - len(general_tags)]
117
- if remove_underscore and len(tag_name) > 3:
118
- tag_name = tag_name.replace("_", " ")
119
-
120
- if tag_name not in undesired_tags:
121
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
122
- character_tag_text += caption_separator + tag_name
123
- combined_tags.append(tag_name)
124
-
125
- # 先頭のカンマを取る
126
- if len(general_tag_text) > 0:
127
- general_tag_text = general_tag_text[len(caption_separator) :]
128
- if len(character_tag_text) > 0:
129
- character_tag_text = character_tag_text[len(caption_separator) :]
130
- tag_text = caption_separator.join(combined_tags)
131
-
132
  return NSFW_flag, IP_flag, tag_text
133
 
134
  def launch(self):
@@ -136,7 +94,7 @@ class webui:
136
  with gr.Row():
137
  with gr.Column():
138
  input_image = gr.Image(type='filepath', label="Analysis Image")
139
- model_id = gr.Textbox(label="MODEL ID", value="SmilingWolf/wd-vit-tagger-v3")
140
  output_0 = gr.Textbox(label="NSFW Flag")
141
  output_1 = gr.Textbox(label="IP Flag")
142
  output_2 = gr.Textbox(label="Tags")
@@ -148,8 +106,8 @@ class webui:
148
  outputs=[output_0, output_1, output_2]
149
  )
150
 
151
- self.demo.launch()
152
 
153
  if __name__ == "__main__":
154
  ui = webui()
155
- ui.launch()
 
31
  image = image.astype(np.float32)
32
  return image
33
 
 
34
  class webui:
35
  def __init__(self):
36
  self.demo = gr.Blocks()
 
40
  print("Hugging Faceからモデルをダウンロード中")
41
  onnx_path = hf_hub_download(model_id, "model.onnx")
42
  csv_path = hf_hub_download(model_id, "selected_tags.csv")
 
 
 
 
43
 
44
+ # ONNXモデルとCSVファイルの読み込み
45
  image = Image.open(image_path)
46
  image = image.convert("RGB") if image.mode != "RGB" else image
47
  image = preprocess_image(image)
48
+ img = np.array([image])
49
+
50
+ ort_sess = ort.InferenceSession(onnx_path) # セッションの生成をここで行う
51
+ prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0]
52
 
53
  with open(csv_path, "r", encoding="utf-8") as f:
54
  reader = csv.reader(f)
55
+ next(reader) # ヘッダーをスキップ
56
  rows = list(reader)
57
+
58
  rating_tags = [row[1] for row in rows if row[2] == "9"]
59
  character_tags = [row[1] for row in rows if row[2] == "4"]
60
  general_tags = [row[1] for row in rows if row[2] == "0"]
61
 
62
+ # タグと評価
63
+ NSFW_flag, IP_flag, tag_text = self.evaluate_tags(prob, rating_tags, character_tags, general_tags)
64
+ return NSFW_flag, IP_flag, tag_text
65
 
66
+ def evaluate_tags(self, prob, rating_tags, character_tags, general_tags):
 
 
67
  thresh = 0.35
 
68
  # NSFW/SFW判定
69
  tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
70
  max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
71
  max_sfw_score = tag_confidences.get("general", 0)
72
+ NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"
 
 
 
 
 
73
 
74
  # 版権キャラクターの可能性を評価
75
  character_tags_with_probs = []
 
81
  prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
82
  character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
83
 
84
+ IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"
 
 
 
 
85
 
86
  # タグを生成
87
+ general_tag_text = ", ".join([general_tags[i] for i in range(len(general_tags)) if prob[i] >= thresh])
88
+ character_tag_text = ", ".join([character_tags[i - len(general_tags)] for i in range(len(general_tags), len(prob)) if prob[i] >= thresh])
89
+ tag_text = f"{general_tag_text}, {character_tag_text}" if character_tag_text else general_tag_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  return NSFW_flag, IP_flag, tag_text
91
 
92
  def launch(self):
 
94
  with gr.Row():
95
  with gr.Column():
96
  input_image = gr.Image(type='filepath', label="Analysis Image")
97
+ model_id = gr.Textbox(label="Model ID", value="SmilingWolf/wd-vit-tagger-v3")
98
  output_0 = gr.Textbox(label="NSFW Flag")
99
  output_1 = gr.Textbox(label="IP Flag")
100
  output_2 = gr.Textbox(label="Tags")
 
106
  outputs=[output_0, output_1, output_2]
107
  )
108
 
109
+ self.demo.launch(share=True) # 公開リンクを設定
110
 
111
  if __name__ == "__main__":
112
  ui = webui()
113
+ ui.launch()