vinesmsuic commited on
Commit
09a289b
1 Parent(s): 17cde38

New models

Browse files
arena_elo/elo_rating/clean_battle_data.py CHANGED
@@ -58,6 +58,10 @@ IDENTITY_WORDS = [
58
  for i in range(len(IDENTITY_WORDS)):
59
  IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
60
 
 
 
 
 
61
 
62
  def remove_html(raw):
63
  if raw.startswith("<h3>"):
@@ -229,7 +233,8 @@ def clean_battle_data(
229
  valid = True
230
  for _model in models:
231
  try:
232
- platform, model_name, task = _model.split("_")
 
233
  except ValueError:
234
  valid = False
235
  break
@@ -240,7 +245,8 @@ def clean_battle_data(
240
  ct_invalid += 1
241
  continue
242
  for i, _model in enumerate(models):
243
- platform, model_name, task = _model.split("_")
 
244
  models[i] = model_name
245
 
246
  # if not all(x.startswith("imagenhub_") and x.endswith("_edition") for x in models):
@@ -253,7 +259,8 @@ def clean_battle_data(
253
  valid = True
254
  for _model in models:
255
  try:
256
- platform, model_name, task = _model.split("_")
 
257
  except ValueError:
258
  valid = False
259
  break
@@ -264,7 +271,8 @@ def clean_battle_data(
264
  ct_invalid += 1
265
  continue
266
  for i, _model in enumerate(models):
267
- platform, model_name, task = _model.split("_")
 
268
  models[i] = model_name
269
  # if not all("playground" in x.lower() or (x.startswith("imagenhub_") and x.endswith("_generation")) for x in models):
270
  # print(f"Invalid model names: {models}")
@@ -280,7 +288,8 @@ def clean_battle_data(
280
  valid = True
281
  for _model in models:
282
  try:
283
- platform, model_name, task = _model.split("_")
 
284
  except ValueError:
285
  valid = False
286
  break
@@ -291,7 +300,8 @@ def clean_battle_data(
291
  ct_invalid += 1
292
  continue
293
  for i, _model in enumerate(models):
294
- platform, model_name, task = _model.split("_")
 
295
  models[i] = model_name
296
 
297
  else:
 
58
  for i in range(len(IDENTITY_WORDS)):
59
  IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
60
 
61
+ def parse_model_name(model_name):
62
+ model_source, *rest = model_name.split("_", 1)
63
+ model_type, model_name = rest[-1], "_".join(rest[:-1])
64
+ return model_source, model_name, model_type
65
 
66
  def remove_html(raw):
67
  if raw.startswith("<h3>"):
 
233
  valid = True
234
  for _model in models:
235
  try:
236
+ #platform, model_name, task = _model.split("_")
237
+ platform, model_name, task = parse_model_name(_model)
238
  except ValueError:
239
  valid = False
240
  break
 
245
  ct_invalid += 1
246
  continue
247
  for i, _model in enumerate(models):
248
+ #platform, model_name, task = _model.split("_")
249
+ platform, model_name, task = parse_model_name(_model)
250
  models[i] = model_name
251
 
252
  # if not all(x.startswith("imagenhub_") and x.endswith("_edition") for x in models):
 
259
  valid = True
260
  for _model in models:
261
  try:
262
+ #platform, model_name, task = _model.split("_")
263
+ platform, model_name, task = parse_model_name(_model)
264
  except ValueError:
265
  valid = False
266
  break
 
271
  ct_invalid += 1
272
  continue
273
  for i, _model in enumerate(models):
274
+ #platform, model_name, task = _model.split("_")
275
+ platform, model_name, task = parse_model_name(_model)
276
  models[i] = model_name
277
  # if not all("playground" in x.lower() or (x.startswith("imagenhub_") and x.endswith("_generation")) for x in models):
278
  # print(f"Invalid model names: {models}")
 
288
  valid = True
289
  for _model in models:
290
  try:
291
+ #platform, model_name, task = _model.split("_")
292
+ platform, model_name, task = parse_model_name(_model)
293
  except ValueError:
294
  valid = False
295
  break
 
300
  ct_invalid += 1
301
  continue
302
  for i, _model in enumerate(models):
303
+ #platform, model_name, task = _model.split("_")
304
+ platform, model_name, task = parse_model_name(_model)
305
  models[i] = model_name
306
 
307
  else:
model/model_registry.py CHANGED
@@ -143,6 +143,28 @@ register_model_info(
143
  "Kolors is a large-scale text-to-image generation model based on latent diffusion",
144
  )
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # regist image edition models
147
  register_model_info(
148
  ["imagenhub_CycleDiffusion_edition"],
 
143
  "Kolors is a large-scale text-to-image generation model based on latent diffusion",
144
  )
145
 
146
+ register_model_info(
147
+ ["fal_AuraFlow_text2image"],
148
+ "AuraFlow",
149
+ "https://huggingface.co/fal/AuraFlow",
150
+ "Opensourced flow-based text-to-image generation model.",
151
+ )
152
+
153
+ register_model_info(
154
+ ["fal_FluxTimestep_text2image"],
155
+ "FLUX.1-schnell",
156
+ "https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux",
157
+ "Flux is a series of text-to-image generation models based on diffusion transformers. Timestep-distilled version.",
158
+ )
159
+
160
+ register_model_info(
161
+ ["fal_FluxGuidance_text2image"],
162
+ "FLUX.1-dev",
163
+ "https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux",
164
+ "Flux is a series of text-to-image generation models based on diffusion transformers. Guidance-distilled version.",
165
+ )
166
+
167
+
168
  # regist image edition models
169
  register_model_info(
170
  ["imagenhub_CycleDiffusion_edition"],
model/models/__init__.py CHANGED
@@ -9,7 +9,8 @@ from .videogenhub_models import load_videogenhub_model
9
  # 'playground_PlayGroundV2_generation', 'playground_PlayGroundV2.5_generation']
10
  IMAGE_GENERATION_MODELS = ['imagenhub_SDXLTurbo_generation','imagenhub_SDXL_generation', 'imagenhub_PixArtAlpha_generation', 'imagenhub_PixArtSigma_generation',
11
  'imagenhub_OpenJourney_generation','imagenhub_SDXLLightning_generation', 'imagenhub_StableCascade_generation', 'imagenhub_HunyuanDiT_generation',
12
- 'playground_PlayGroundV2.5_generation', 'imagenhub_Kolors_generation', 'imagenhub_SD3_generation'] # 'playground_PlayGroundV2_generation'
 
13
  IMAGE_EDITION_MODELS = ['imagenhub_CycleDiffusion_edition', 'imagenhub_Pix2PixZero_edition', 'imagenhub_Prompt2prompt_edition',
14
  'imagenhub_SDEdit_edition', 'imagenhub_InstructPix2Pix_edition',
15
  'imagenhub_MagicBrush_edition', 'imagenhub_PNP_edition',
@@ -32,7 +33,9 @@ def load_pipeline(model_name):
32
  the name is the name of the model used to load the model
33
  the type is the type of the model, either generation or edition
34
  """
35
- model_source, model_name, model_type = model_name.split("_")
 
 
36
  if model_source == "imagenhub":
37
  pipe = load_imagenhub_model(model_name, model_type)
38
  elif model_source == "playground":
 
9
  # 'playground_PlayGroundV2_generation', 'playground_PlayGroundV2.5_generation']
10
  IMAGE_GENERATION_MODELS = ['imagenhub_SDXLTurbo_generation','imagenhub_SDXL_generation', 'imagenhub_PixArtAlpha_generation', 'imagenhub_PixArtSigma_generation',
11
  'imagenhub_OpenJourney_generation','imagenhub_SDXLLightning_generation', 'imagenhub_StableCascade_generation', 'imagenhub_HunyuanDiT_generation',
12
+ 'playground_PlayGroundV2.5_generation', 'imagenhub_Kolors_generation', 'imagenhub_SD3_generation',
13
+ 'fal_AuraFlow_text2image', 'fal_FluxTimestep_text2image', 'fal_FluxGuidance_text2image'] # 'playground_PlayGroundV2_generation'
14
  IMAGE_EDITION_MODELS = ['imagenhub_CycleDiffusion_edition', 'imagenhub_Pix2PixZero_edition', 'imagenhub_Prompt2prompt_edition',
15
  'imagenhub_SDEdit_edition', 'imagenhub_InstructPix2Pix_edition',
16
  'imagenhub_MagicBrush_edition', 'imagenhub_PNP_edition',
 
33
  the name is the name of the model used to load the model
34
  the type is the type of the model, either generation or edition
35
  """
36
+ #model_source, model_name, model_type = model_name.split("_")
37
+ model_source, *rest = model_name.split("_", 1)
38
+ model_type, model_name = rest[-1], "_".join(rest[:-1])
39
  if model_source == "imagenhub":
40
  pipe = load_imagenhub_model(model_name, model_type)
41
  elif model_source == "playground":
model/models/fal_api_models.py CHANGED
@@ -6,7 +6,8 @@ import os
6
  import base64
7
 
8
  FAL_MODEl_NAME_MAP = {"SDXL": "fast-sdxl", "SDXLTurbo": "fast-turbo-diffusion", "SDXLLightning": "fast-lightning-sdxl",
9
- "LCM(v1.5/XL)": "fast-lcm-diffusion", "PixArtSigma": "pixart-sigma", "StableCascade": "stable-cascade"}
 
10
 
11
  class FalModel():
12
  def __init__(self, model_name, model_type):
 
6
  import base64
7
 
8
  FAL_MODEl_NAME_MAP = {"SDXL": "fast-sdxl", "SDXLTurbo": "fast-turbo-diffusion", "SDXLLightning": "fast-lightning-sdxl",
9
+ "LCM(v1.5/XL)": "fast-lcm-diffusion", "PixArtSigma": "pixart-sigma", "StableCascade": "stable-cascade",
10
+ "AuraFlow": "aura-flow", "FluxTimestep": "flux/schnell", "FluxGuidance": "flux/dev"}
11
 
12
  class FalModel():
13
  def __init__(self, model_name, model_type):