SmilingWolf commited on
Commit
b079c7b
·
1 Parent(s): 4148c50

Update app.py

Browse files

Add newly released ConvNextV2 model

Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -21,6 +21,7 @@ DESCRIPTION = """
21
  Demo for:
22
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
23
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
 
24
  - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
25
 
26
  Includes "ready to copy" prompt and a prompt analyzer.
@@ -36,6 +37,7 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
36
  HF_TOKEN = os.environ["HF_TOKEN"]
37
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
38
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
 
39
  VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
40
  MODEL_FILENAME = "model.onnx"
41
  LABEL_FILENAME = "selected_tags.csv"
@@ -65,6 +67,8 @@ def change_model(model_name):
65
  model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
66
  elif model_name == "ConvNext":
67
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
 
 
68
  elif model_name == "ViT":
69
  model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
70
 
@@ -74,7 +78,7 @@ def change_model(model_name):
74
 
75
  def load_labels() -> list[str]:
76
  path = huggingface_hub.hf_hub_download(
77
- SWIN_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
78
  )
79
  df = pd.read_csv(path)
80
 
@@ -209,11 +213,11 @@ def predict(
209
 
210
  def main():
211
  global loaded_models
212
- loaded_models = {"SwinV2": None, "ConvNext": None, "ViT": None}
213
 
214
  args = parse_args()
215
 
216
- change_model("SwinV2")
217
 
218
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
219
 
@@ -229,7 +233,7 @@ def main():
229
  fn=func,
230
  inputs=[
231
  gr.Image(type="pil", label="Input"),
232
- gr.Radio(["SwinV2", "ConvNext", "ViT"], value="SwinV2", label="Model"),
233
  gr.Slider(
234
  0,
235
  1,
@@ -253,7 +257,7 @@ def main():
253
  gr.Label(label="Output (tags)"),
254
  gr.HTML(),
255
  ],
256
- examples=[["power.jpg", "SwinV2", 0.35, 0.85]],
257
  title=TITLE,
258
  description=DESCRIPTION,
259
  allow_flagging="never",
 
21
  Demo for:
22
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
23
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
24
+ - [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
25
  - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
26
 
27
  Includes "ready to copy" prompt and a prompt analyzer.
 
37
  HF_TOKEN = os.environ["HF_TOKEN"]
38
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
39
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
40
+ CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
41
  VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
42
  MODEL_FILENAME = "model.onnx"
43
  LABEL_FILENAME = "selected_tags.csv"
 
67
  model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
68
  elif model_name == "ConvNext":
69
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
70
+ elif model_name == "ConvNextV2":
71
+ model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME)
72
  elif model_name == "ViT":
73
  model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
74
 
 
78
 
79
  def load_labels() -> list[str]:
80
  path = huggingface_hub.hf_hub_download(
81
+ CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
82
  )
83
  df = pd.read_csv(path)
84
 
 
213
 
214
  def main():
215
  global loaded_models
216
+ loaded_models = {"SwinV2": None, "ConvNext": None, "ConvNextV2": None, "ViT": None}
217
 
218
  args = parse_args()
219
 
220
+ change_model("ConvNextV2")
221
 
222
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
223
 
 
233
  fn=func,
234
  inputs=[
235
  gr.Image(type="pil", label="Input"),
236
+ gr.Radio(["SwinV2", "ConvNext", "ConvNextV2", "ViT"], value="ConvNextV2", label="Model"),
237
  gr.Slider(
238
  0,
239
  1,
 
257
  gr.Label(label="Output (tags)"),
258
  gr.HTML(),
259
  ],
260
+ examples=[["power.jpg", "ConvNextV2", 0.35, 0.85]],
261
  title=TITLE,
262
  description=DESCRIPTION,
263
  allow_flagging="never",