Spaces:
Running
on
Zero
Running
on
Zero
tori29umai
commited on
Commit
•
aa2fbfe
1
Parent(s):
607a0f5
Update
Browse files
app.py
CHANGED
@@ -31,7 +31,6 @@ def preprocess_image(image):
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
34 |
-
|
35 |
class webui:
|
36 |
def __init__(self):
|
37 |
self.demo = gr.Blocks()
|
@@ -41,39 +40,36 @@ class webui:
|
|
41 |
print("Hugging Faceからモデルをダウンロード中")
|
42 |
onnx_path = hf_hub_download(model_id, "model.onnx")
|
43 |
csv_path = hf_hub_download(model_id, "selected_tags.csv")
|
44 |
-
ort_sess = ort.InferenceSession(onnx_path)
|
45 |
-
|
46 |
-
print("ONNXモデルを実行中")
|
47 |
-
print(f"ONNXモデルのパス: {onnx_path}")
|
48 |
|
|
|
49 |
image = Image.open(image_path)
|
50 |
image = image.convert("RGB") if image.mode != "RGB" else image
|
51 |
image = preprocess_image(image)
|
|
|
|
|
|
|
|
|
52 |
|
53 |
with open(csv_path, "r", encoding="utf-8") as f:
|
54 |
reader = csv.reader(f)
|
55 |
-
|
56 |
rows = list(reader)
|
|
|
57 |
rating_tags = [row[1] for row in rows if row[2] == "9"]
|
58 |
character_tags = [row[1] for row in rows if row[2] == "4"]
|
59 |
general_tags = [row[1] for row in rows if row[2] == "0"]
|
60 |
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0] # ONNXモデルからの出力
|
64 |
-
|
65 |
thresh = 0.35
|
66 |
-
|
67 |
# NSFW/SFW判定
|
68 |
tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
|
69 |
max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
|
70 |
max_sfw_score = tag_confidences.get("general", 0)
|
71 |
-
NSFW_flag =
|
72 |
-
|
73 |
-
if max_nsfw_score > max_sfw_score:
|
74 |
-
NSFW_flag = "NSFWの可能性が高いです"
|
75 |
-
else:
|
76 |
-
NSFW_flag = "SFWの可能性が高いです"
|
77 |
|
78 |
# 版権キャラクターの可能性を評価
|
79 |
character_tags_with_probs = []
|
@@ -85,50 +81,12 @@ class webui:
|
|
85 |
prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
|
86 |
character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
|
87 |
|
88 |
-
IP_flag =
|
89 |
-
if character_tags_with_probs:
|
90 |
-
IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
|
91 |
-
else:
|
92 |
-
IP_flag = "版権キャラクターの可能性が低いと思われます"
|
93 |
|
94 |
# タグを生成
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
general_tag_text = ""
|
99 |
-
character_tag_text = ""
|
100 |
-
remove_underscore = True
|
101 |
-
caption_separator = ", "
|
102 |
-
general_threshold = 0.35
|
103 |
-
character_threshold = 0.35
|
104 |
-
|
105 |
-
for i, p in enumerate(prob[4:]):
|
106 |
-
if i < len(general_tags) and p >= general_threshold:
|
107 |
-
tag_name = general_tags[i]
|
108 |
-
if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
|
109 |
-
tag_name = tag_name.replace("_", " ")
|
110 |
-
|
111 |
-
if tag_name not in undesired_tags:
|
112 |
-
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
113 |
-
general_tag_text += caption_separator + tag_name
|
114 |
-
combined_tags.append(tag_name)
|
115 |
-
elif i >= len(general_tags) and p >= character_threshold:
|
116 |
-
tag_name = character_tags[i - len(general_tags)]
|
117 |
-
if remove_underscore and len(tag_name) > 3:
|
118 |
-
tag_name = tag_name.replace("_", " ")
|
119 |
-
|
120 |
-
if tag_name not in undesired_tags:
|
121 |
-
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
122 |
-
character_tag_text += caption_separator + tag_name
|
123 |
-
combined_tags.append(tag_name)
|
124 |
-
|
125 |
-
# 先頭のカンマを取る
|
126 |
-
if len(general_tag_text) > 0:
|
127 |
-
general_tag_text = general_tag_text[len(caption_separator) :]
|
128 |
-
if len(character_tag_text) > 0:
|
129 |
-
character_tag_text = character_tag_text[len(caption_separator) :]
|
130 |
-
tag_text = caption_separator.join(combined_tags)
|
131 |
-
|
132 |
return NSFW_flag, IP_flag, tag_text
|
133 |
|
134 |
def launch(self):
|
@@ -136,7 +94,7 @@ class webui:
|
|
136 |
with gr.Row():
|
137 |
with gr.Column():
|
138 |
input_image = gr.Image(type='filepath', label="Analysis Image")
|
139 |
-
model_id = gr.Textbox(label="
|
140 |
output_0 = gr.Textbox(label="NSFW Flag")
|
141 |
output_1 = gr.Textbox(label="IP Flag")
|
142 |
output_2 = gr.Textbox(label="Tags")
|
@@ -148,8 +106,8 @@ class webui:
|
|
148 |
outputs=[output_0, output_1, output_2]
|
149 |
)
|
150 |
|
151 |
-
self.demo.launch()
|
152 |
|
153 |
if __name__ == "__main__":
|
154 |
ui = webui()
|
155 |
-
ui.launch()
|
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
|
|
34 |
class webui:
|
35 |
def __init__(self):
|
36 |
self.demo = gr.Blocks()
|
|
|
40 |
print("Hugging Faceからモデルをダウンロード中")
|
41 |
onnx_path = hf_hub_download(model_id, "model.onnx")
|
42 |
csv_path = hf_hub_download(model_id, "selected_tags.csv")
|
|
|
|
|
|
|
|
|
43 |
|
44 |
+
# ONNXモデルとCSVファイルの読み込み
|
45 |
image = Image.open(image_path)
|
46 |
image = image.convert("RGB") if image.mode != "RGB" else image
|
47 |
image = preprocess_image(image)
|
48 |
+
img = np.array([image])
|
49 |
+
|
50 |
+
ort_sess = ort.InferenceSession(onnx_path) # セッションの生成をここで行う
|
51 |
+
prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0]
|
52 |
|
53 |
with open(csv_path, "r", encoding="utf-8") as f:
|
54 |
reader = csv.reader(f)
|
55 |
+
next(reader) # ヘッダーをスキップ
|
56 |
rows = list(reader)
|
57 |
+
|
58 |
rating_tags = [row[1] for row in rows if row[2] == "9"]
|
59 |
character_tags = [row[1] for row in rows if row[2] == "4"]
|
60 |
general_tags = [row[1] for row in rows if row[2] == "0"]
|
61 |
|
62 |
+
# タグと評価
|
63 |
+
NSFW_flag, IP_flag, tag_text = self.evaluate_tags(prob, rating_tags, character_tags, general_tags)
|
64 |
+
return NSFW_flag, IP_flag, tag_text
|
65 |
|
66 |
+
def evaluate_tags(self, prob, rating_tags, character_tags, general_tags):
|
|
|
|
|
67 |
thresh = 0.35
|
|
|
68 |
# NSFW/SFW判定
|
69 |
tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
|
70 |
max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
|
71 |
max_sfw_score = tag_confidences.get("general", 0)
|
72 |
+
NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
# 版権キャラクターの可能性を評価
|
75 |
character_tags_with_probs = []
|
|
|
81 |
prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
|
82 |
character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
|
83 |
|
84 |
+
IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# タグを生成
|
87 |
+
general_tag_text = ", ".join([general_tags[i] for i in range(len(general_tags)) if prob[i] >= thresh])
|
88 |
+
character_tag_text = ", ".join([character_tags[i - len(general_tags)] for i in range(len(general_tags), len(prob)) if prob[i] >= thresh])
|
89 |
+
tag_text = f"{general_tag_text}, {character_tag_text}" if character_tag_text else general_tag_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
return NSFW_flag, IP_flag, tag_text
|
91 |
|
92 |
def launch(self):
|
|
|
94 |
with gr.Row():
|
95 |
with gr.Column():
|
96 |
input_image = gr.Image(type='filepath', label="Analysis Image")
|
97 |
+
model_id = gr.Textbox(label="Model ID", value="SmilingWolf/wd-vit-tagger-v3")
|
98 |
output_0 = gr.Textbox(label="NSFW Flag")
|
99 |
output_1 = gr.Textbox(label="IP Flag")
|
100 |
output_2 = gr.Textbox(label="Tags")
|
|
|
106 |
outputs=[output_0, output_1, output_2]
|
107 |
)
|
108 |
|
109 |
+
self.demo.launch(share=True) # 公開リンクを設定
|
110 |
|
111 |
if __name__ == "__main__":
|
112 |
ui = webui()
|
113 |
+
ui.launch()
|