avans06 commited on
Commit
bb844f7
1 Parent(s): e444e5b

WaifuDiffusion Tagger now supports uploading multiple images

Browse files

1. Changed the interface’s "Image" field to "Gallery," now supporting multi-image uploads.

2. Added a "Download ZIP" option for results, allowing all reverse-engineered TXT files from the images to be compressed into a ZIP file for a single download.

3. Other fields in the "Output" section currently only display results for the first image.

4. The "Examples" feature is temporarily disabled, as it encounters issues when used with the Gallery view, and a solution has not yet been found.

5. Add a checkbox: "Merge characters into the string output," with the default set to checked.

6. Add two additional tag fields, Prepend and Append. The content is comma-separated and will automatically be added to the string output.

7. Upgrade the project's Gradio version to 5.6.0

8. Add an Upload Images button to allow adding images to the gallery at any time.

9. Add a Remove Selected Image button to remove the currently selected image from the gallery when clicked.

10. Add support for automatically displaying the image's output strings, rating, characters, and tags when selecting an image from the gallery.

Files changed (4) hide show
  1. README.md +1 -1
  2. app.py +197 -67
  3. requirements.txt +1 -1
  4. webui.bat +73 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 5.4.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.6.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -7,6 +7,10 @@ import numpy as np
7
  import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
 
 
 
 
10
 
11
  TITLE = "WaifuDiffusion Tagger"
12
  DESCRIPTION = """
@@ -56,6 +60,8 @@ kaomojis = [
56
  "||_||",
57
  ]
58
 
 
 
59
 
60
  def parse_args() -> argparse.Namespace:
61
  parser = argparse.ArgumentParser()
@@ -130,7 +136,9 @@ class Predictor:
130
  self.last_loaded_repo = model_repo
131
  self.model = model
132
 
133
- def prepare_image(self, image):
 
 
134
  target_size = self.model_target_size
135
 
136
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
@@ -161,61 +169,149 @@ class Predictor:
161
 
162
  return np.expand_dims(image_array, axis=0)
163
 
 
 
 
 
 
 
 
164
  def predict(
165
  self,
166
- image,
167
  model_repo,
168
  general_thresh,
169
  general_mcut_enabled,
170
  character_thresh,
171
  character_mcut_enabled,
 
 
 
172
  ):
173
  self.load_model(model_repo)
174
-
175
- image = self.prepare_image(image)
176
-
177
- input_name = self.model.get_inputs()[0].name
178
- label_name = self.model.get_outputs()[0].name
179
- preds = self.model.run([label_name], {input_name: image})[0]
180
-
181
- labels = list(zip(self.tag_names, preds[0].astype(float)))
182
-
183
- # First 4 labels are actually ratings: pick one with argmax
184
- ratings_names = [labels[i] for i in self.rating_indexes]
185
- rating = dict(ratings_names)
186
-
187
- # Then we have general tags: pick any where prediction confidence > threshold
188
- general_names = [labels[i] for i in self.general_indexes]
189
-
190
- if general_mcut_enabled:
191
- general_probs = np.array([x[1] for x in general_names])
192
- general_thresh = mcut_threshold(general_probs)
193
-
194
- general_res = [x for x in general_names if x[1] > general_thresh]
195
- general_res = dict(general_res)
196
-
197
- # Everything else is characters: pick any where prediction confidence > threshold
198
- character_names = [labels[i] for i in self.character_indexes]
199
-
200
- if character_mcut_enabled:
201
- character_probs = np.array([x[1] for x in character_names])
202
- character_thresh = mcut_threshold(character_probs)
203
- character_thresh = max(0.15, character_thresh)
204
-
205
- character_res = [x for x in character_names if x[1] > character_thresh]
206
- character_res = dict(character_res)
207
-
208
- sorted_general_strings = sorted(
209
- general_res.items(),
210
- key=lambda x: x[1],
211
- reverse=True,
212
- )
213
- sorted_general_strings = [x[0] for x in sorted_general_strings]
214
- sorted_general_strings = (
215
- ", ".join(sorted_general_strings).replace("(", "\(").replace(")", "\)")
216
- )
217
-
218
- return sorted_general_strings, rating, character_res, general_res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
 
221
  def main():
@@ -225,10 +321,10 @@ def main():
225
 
226
  dropdown_list = [
227
  SWINV2_MODEL_DSV3_REPO,
 
228
  CONV_MODEL_DSV3_REPO,
229
  VIT_MODEL_DSV3_REPO,
230
  VIT_LARGE_MODEL_DSV3_REPO,
231
- EVA02_LARGE_MODEL_DSV3_REPO,
232
  MOAT_MODEL_DSV2_REPO,
233
  SWIN_MODEL_DSV2_REPO,
234
  CONV_MODEL_DSV2_REPO,
@@ -244,7 +340,14 @@ def main():
244
  gr.Markdown(value=DESCRIPTION)
245
  with gr.Row():
246
  with gr.Column(variant="panel"):
247
- image = gr.Image(type="pil", image_mode="RGBA", label="Input")
 
 
 
 
 
 
 
248
  model_repo = gr.Dropdown(
249
  dropdown_list,
250
  value=SWINV2_MODEL_DSV3_REPO,
@@ -278,27 +381,40 @@ def main():
278
  label="Use MCut threshold",
279
  scale=1,
280
  )
 
 
 
 
 
 
 
 
 
281
  with gr.Row():
282
  clear = gr.ClearButton(
283
  components=[
284
- image,
285
  model_repo,
286
  general_thresh,
287
  general_mcut_enabled,
288
  character_thresh,
289
  character_mcut_enabled,
 
 
 
290
  ],
291
  variant="secondary",
292
  size="lg",
293
  )
294
- submit = gr.Button(value="Submit", variant="primary", size="lg")
295
  with gr.Column(variant="panel"):
296
- sorted_general_strings = gr.Textbox(label="Output (string)")
 
297
  rating = gr.Label(label="Rating")
298
  character_res = gr.Label(label="Output (characters)")
299
  general_res = gr.Label(label="Output (tags)")
300
  clear.add(
301
  [
 
302
  sorted_general_strings,
303
  rating,
304
  character_res,
@@ -306,33 +422,47 @@ def main():
306
  ]
307
  )
308
 
 
 
 
 
 
 
 
 
309
  submit.click(
310
  predictor.predict,
311
  inputs=[
312
- image,
313
- model_repo,
314
- general_thresh,
315
- general_mcut_enabled,
316
- character_thresh,
317
- character_mcut_enabled,
318
- ],
319
- outputs=[sorted_general_strings, rating, character_res, general_res],
320
- )
321
-
322
- gr.Examples(
323
- [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
324
- inputs=[
325
- image,
326
  model_repo,
327
  general_thresh,
328
  general_mcut_enabled,
329
  character_thresh,
330
  character_mcut_enabled,
 
 
 
331
  ],
 
332
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
  demo.queue(max_size=10)
335
- demo.launch()
336
 
337
 
338
  if __name__ == "__main__":
 
7
  import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
10
+ import traceback
11
+ import tempfile
12
+ import zipfile
13
+ from datetime import datetime
14
 
15
  TITLE = "WaifuDiffusion Tagger"
16
  DESCRIPTION = """
 
60
  "||_||",
61
  ]
62
 
63
+ tag_results = {}
64
+
65
 
66
  def parse_args() -> argparse.Namespace:
67
  parser = argparse.ArgumentParser()
 
136
  self.last_loaded_repo = model_repo
137
  self.model = model
138
 
139
+ def prepare_image(self, path):
140
+ image = Image.open(path)
141
+ image = image.convert("RGBA")
142
  target_size = self.model_target_size
143
 
144
  canvas = Image.new("RGBA", image.size, (255, 255, 255))
 
169
 
170
  return np.expand_dims(image_array, axis=0)
171
 
172
+ def create_file(self, text: str, directory: str, fileName: str) -> str:
173
+ # Write the text to a file
174
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
175
+ file.write(text)
176
+
177
+ return file.name
178
+
179
  def predict(
180
  self,
181
+ gallery,
182
  model_repo,
183
  general_thresh,
184
  general_mcut_enabled,
185
  character_thresh,
186
  character_mcut_enabled,
187
+ characters_merge_enabled,
188
+ additional_tags_prepend,
189
+ additional_tags_append,
190
  ):
191
  self.load_model(model_repo)
192
+ # Result
193
+ txt_infos = []
194
+ output_dir = tempfile.mkdtemp()
195
+ if not os.path.exists(output_dir):
196
+ os.makedirs(output_dir)
197
+
198
+ sorted_general_strings = ""
199
+ rating = None
200
+ character_res = None
201
+ general_res = None
202
+
203
+ tag_results.clear()
204
+
205
+ prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
206
+ append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
207
+ if prepend_list and append_list:
208
+ append_list = [item for item in append_list if item not in prepend_list]
209
+
210
+ for idx, value in enumerate(gallery):
211
+ try:
212
+ image_path = value[0]
213
+ image_name = os.path.splitext(os.path.basename(image_path))[0]
214
+
215
+ image = self.prepare_image(image_path)
216
+
217
+ input_name = self.model.get_inputs()[0].name
218
+ label_name = self.model.get_outputs()[0].name
219
+ preds = self.model.run([label_name], {input_name: image})[0]
220
+
221
+ labels = list(zip(self.tag_names, preds[0].astype(float)))
222
+
223
+ # First 4 labels are actually ratings: pick one with argmax
224
+ ratings_names = [labels[i] for i in self.rating_indexes]
225
+ rating = dict(ratings_names)
226
+
227
+ # Then we have general tags: pick any where prediction confidence > threshold
228
+ general_names = [labels[i] for i in self.general_indexes]
229
+
230
+ if general_mcut_enabled:
231
+ general_probs = np.array([x[1] for x in general_names])
232
+ general_thresh = mcut_threshold(general_probs)
233
+
234
+ general_res = [x for x in general_names if x[1] > general_thresh]
235
+ general_res = dict(general_res)
236
+
237
+ # Everything else is characters: pick any where prediction confidence > threshold
238
+ character_names = [labels[i] for i in self.character_indexes]
239
+
240
+ if character_mcut_enabled:
241
+ character_probs = np.array([x[1] for x in character_names])
242
+ character_thresh = mcut_threshold(character_probs)
243
+ character_thresh = max(0.15, character_thresh)
244
+
245
+ character_res = [x for x in character_names if x[1] > character_thresh]
246
+ character_res = dict(character_res)
247
+ character_list = list(character_res.keys())
248
+
249
+ sorted_general_list = sorted(
250
+ general_res.items(),
251
+ key=lambda x: x[1],
252
+ reverse=True,
253
+ )
254
+ sorted_general_list = [x[0] for x in sorted_general_list]
255
+ #Remove values from character_list that already exist in sorted_general_list
256
+ character_list = [item for item in character_list if item not in sorted_general_list]
257
+ #Remove values from sorted_general_list that already exist in prepend_list or append_list
258
+ if prepend_list:
259
+ sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
260
+ if append_list:
261
+ sorted_general_list = [item for item in sorted_general_list if item not in append_list]
262
+
263
+ sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
264
+
265
+ txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
266
+ txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
267
+
268
+ tag_results[image_path] = { "strings": sorted_general_strings, "rating": rating, "character_res": character_res, "general_res": general_res }
269
+
270
+ except Exception as e:
271
+ print(traceback.format_exc())
272
+ print("Error predict: " + str(e))
273
+
274
+ # Result
275
+ download = []
276
+ if txt_infos is not None and len(txt_infos) > 0:
277
+ downloadZipPath = os.path.join(output_dir, "images-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
278
+ with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
279
+ for info in txt_infos:
280
+ # Get file name from lookup
281
+ taggers_zip.write(info["path"], arcname=info["name"])
282
+ download.append(downloadZipPath)
283
+
284
+ return download, sorted_general_strings, rating, character_res, general_res
285
+
286
+ def get_selection_from_gallery(gallery: list, selected_state: gr.SelectData):
287
+ if not selected_state:
288
+ return selected_state
289
+
290
+ tag_result = { "strings": "", "rating": "", "character_res": "", "general_res": "" }
291
+ if selected_state.value["image"]["path"] in tag_results:
292
+ tag_result = tag_results[selected_state.value["image"]["path"]]
293
+
294
+ return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"]
295
+
296
+ def add_images_to_gallery(gallery: list, images):
297
+ if gallery is None:
298
+ gallery = []
299
+ if not images:
300
+ return gallery
301
+
302
+ # Combine the new images with the existing gallery images
303
+ gallery.extend(images)
304
+ return gallery
305
+
306
+ def remove_image_from_gallery(gallery: list, selected_image: str):
307
+ if not gallery or not selected_image:
308
+ return gallery
309
+
310
+ selected_image = eval(selected_image)
311
+ # Remove the selected image from the gallery
312
+ if selected_image in gallery:
313
+ gallery.remove(selected_image)
314
+ return gallery
315
 
316
 
317
  def main():
 
321
 
322
  dropdown_list = [
323
  SWINV2_MODEL_DSV3_REPO,
324
+ EVA02_LARGE_MODEL_DSV3_REPO,
325
  CONV_MODEL_DSV3_REPO,
326
  VIT_MODEL_DSV3_REPO,
327
  VIT_LARGE_MODEL_DSV3_REPO,
 
328
  MOAT_MODEL_DSV2_REPO,
329
  SWIN_MODEL_DSV2_REPO,
330
  CONV_MODEL_DSV2_REPO,
 
340
  gr.Markdown(value=DESCRIPTION)
341
  with gr.Row():
342
  with gr.Column(variant="panel"):
343
+ with gr.Row():
344
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
345
+ with gr.Row():
346
+ gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Input")
347
+ with gr.Row():
348
+ upload_button = gr.UploadButton("Upload Images", file_types=["image"], file_count="multiple", size="sm")
349
+ remove_button = gr.Button("Remove Selected Image", size="sm")
350
+
351
  model_repo = gr.Dropdown(
352
  dropdown_list,
353
  value=SWINV2_MODEL_DSV3_REPO,
 
381
  label="Use MCut threshold",
382
  scale=1,
383
  )
384
+ with gr.Row():
385
+ characters_merge_enabled = gr.Checkbox(
386
+ value=True,
387
+ label="Merge characters into the string output",
388
+ scale=1,
389
+ )
390
+ with gr.Row():
391
+ additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
392
+ additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
393
  with gr.Row():
394
  clear = gr.ClearButton(
395
  components=[
396
+ gallery,
397
  model_repo,
398
  general_thresh,
399
  general_mcut_enabled,
400
  character_thresh,
401
  character_mcut_enabled,
402
+ characters_merge_enabled,
403
+ additional_tags_prepend,
404
+ additional_tags_append,
405
  ],
406
  variant="secondary",
407
  size="lg",
408
  )
 
409
  with gr.Column(variant="panel"):
410
+ download_file = gr.File(label="Output (Download)")
411
+ sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True)
412
  rating = gr.Label(label="Rating")
413
  character_res = gr.Label(label="Output (characters)")
414
  general_res = gr.Label(label="Output (tags)")
415
  clear.add(
416
  [
417
+ download_file,
418
  sorted_general_strings,
419
  rating,
420
  character_res,
 
422
  ]
423
  )
424
 
425
+ # When the upload button is clicked, add the new images to the gallery
426
+ upload_button.upload(add_images_to_gallery, inputs=[gallery, upload_button], outputs=gallery)
427
+ # Event to update the selected image when an image is clicked in the gallery
428
+ selected_image = gr.Textbox(label="Selected Image", visible=False)
429
+ gallery.select(get_selection_from_gallery, inputs=gallery, outputs=[selected_image, sorted_general_strings, rating, character_res, general_res])
430
+ # Event to remove a selected image from the gallery
431
+ remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
432
+
433
  submit.click(
434
  predictor.predict,
435
  inputs=[
436
+ gallery,
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  model_repo,
438
  general_thresh,
439
  general_mcut_enabled,
440
  character_thresh,
441
  character_mcut_enabled,
442
+ characters_merge_enabled,
443
+ additional_tags_prepend,
444
+ additional_tags_append,
445
  ],
446
+ outputs=[download_file, sorted_general_strings, rating, character_res, general_res],
447
  )
448
+
449
+ # gr.Examples(
450
+ # [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
451
+ # inputs=[
452
+ # gallery,
453
+ # model_repo,
454
+ # general_thresh,
455
+ # general_mcut_enabled,
456
+ # character_thresh,
457
+ # character_mcut_enabled,
458
+ # characters_merge_enabled,
459
+ # additional_tags_prepend,
460
+ # additional_tags_append,
461
+ # ],
462
+ # )
463
 
464
  demo.queue(max_size=10)
465
+ demo.launch(inbrowser=True)
466
 
467
 
468
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -2,5 +2,5 @@ pillow>=9.0.0
2
  onnxruntime>=1.12.0
3
  huggingface-hub
4
 
5
- gradio==5.4.0
6
  pandas
 
2
  onnxruntime>=1.12.0
3
  huggingface-hub
4
 
5
+ gradio==5.6.0
6
  pandas
webui.bat ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ :: The source of the webui.bat file is stable-diffusion-webui
4
+ :: set COMMANDLINE_ARGS=--whisper_implementation faster-whisper --input_audio_max_duration -1 --default_model_name large-v2 --auto_parallel True --output_dir output --vad_max_merge_size 90 --save_downloaded_files --autolaunch
5
+
6
+ if not defined PYTHON (set PYTHON=python)
7
+ if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
8
+
9
+ mkdir tmp 2>NUL
10
+
11
+ %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
12
+ if %ERRORLEVEL% == 0 goto :check_pip
13
+ echo Couldn't launch python
14
+ goto :show_stdout_stderr
15
+
16
+ :check_pip
17
+ %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
18
+ if %ERRORLEVEL% == 0 goto :start_venv
19
+ if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
20
+ %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
21
+ if %ERRORLEVEL% == 0 goto :start_venv
22
+ echo Couldn't install pip
23
+ goto :show_stdout_stderr
24
+
25
+ :start_venv
26
+ if ["%VENV_DIR%"] == ["-"] goto :skip_venv
27
+ if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
28
+
29
+ dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
30
+ if %ERRORLEVEL% == 0 goto :activate_venv
31
+
32
+ for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
33
+ echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
34
+ %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
35
+ if %ERRORLEVEL% == 0 goto :activate_venv
36
+ echo Unable to create venv in directory "%VENV_DIR%"
37
+ goto :show_stdout_stderr
38
+
39
+ :activate_venv
40
+ set PYTHON="%VENV_DIR%\Scripts\Python.exe"
41
+ echo venv %PYTHON%
42
+
43
+ :skip_venv
44
+ goto :launch
45
+
46
+ :launch
47
+ %PYTHON% app.py %COMMANDLINE_ARGS% %*
48
+ pause
49
+ exit /b
50
+
51
+ :show_stdout_stderr
52
+
53
+ echo.
54
+ echo exit code: %errorlevel%
55
+
56
+ for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
57
+ if %size% equ 0 goto :show_stderr
58
+ echo.
59
+ echo stdout:
60
+ type tmp\stdout.txt
61
+
62
+ :show_stderr
63
+ for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
64
+ if %size% equ 0 goto :show_stderr
65
+ echo.
66
+ echo stderr:
67
+ type tmp\stderr.txt
68
+
69
+ :endofscript
70
+
71
+ echo.
72
+ echo Launch unsuccessful. Exiting.
73
+ pause