Spaces:
Running
on
Zero
Running
on
Zero
from .imagenhub_models import load_imagenhub_model | |
from .playground_api import load_playground_model | |
from .fal_api_models import load_fal_model | |
from .videogenhub_models import load_videogenhub_model | |
from ..model_registry import model_info | |
# Model identifiers use the "{source}_{name}_{type}" convention that
# load_pipeline() parses; the source prefix selects the backend loader.

# Earlier image-generation line-up, retained for reference only (inactive):
# IMAGE_GENERATION_MODELS = ['fal_LCM(v1.5/XL)_text2image','fal_SDXLTurbo_text2image','fal_SDXL_text2image', 'imagenhub_PixArtAlpha_generation', 'fal_PixArtSigma_text2image',
#     'imagenhub_OpenJourney_generation','fal_SDXLLightning_text2image', 'fal_StableCascade_text2image',
#     'playground_PlayGroundV2_generation', 'playground_PlayGroundV2.5_generation']

# Active text-to-image models.
IMAGE_GENERATION_MODELS = [
    "imagenhub_SDXLTurbo_generation",
    "imagenhub_SDXL_generation",
    "imagenhub_PixArtAlpha_generation",
    "imagenhub_PixArtSigma_generation",
    "imagenhub_OpenJourney_generation",
    "imagenhub_SDXLLightning_generation",
    "imagenhub_StableCascade_generation",
    "imagenhub_HunyuanDiT_generation",
    "playground_PlayGroundV2.5_generation",
    "imagenhub_Kolors_generation",
    "imagenhub_SD3_generation",
    "fal_AuraFlow_text2image",
    "fal_FLUX1schnell_text2image",
    "fal_FLUX1dev_text2image",
    # "playground_PlayGroundV2_generation",  (disabled)
]

# Active image-editing models.
IMAGE_EDITION_MODELS = [
    "imagenhub_CycleDiffusion_edition",
    "imagenhub_Pix2PixZero_edition",
    "imagenhub_Prompt2prompt_edition",
    "imagenhub_SDEdit_edition",
    "imagenhub_InstructPix2Pix_edition",
    "imagenhub_MagicBrush_edition",
    "imagenhub_PNP_edition",
    "imagenhub_InfEdit_edition",
    "imagenhub_CosXLEdit_edition",
    "imagenhub_UltraEdit_edition",
]

# Active text-to-video models; commented-out entries are currently disabled.
VIDEO_GENERATION_MODELS = [
    "fal_AnimateDiff_text2video",
    "fal_AnimateDiffTurbo_text2video",
    # "videogenhub_LaVie_generation",
    "videogenhub_VideoCrafter2_generation",
    # "videogenhub_ModelScope_generation",
    "videogenhub_CogVideoX-2B_generation",
    "videogenhub_OpenSora12_generation",
    # "videogenhub_OpenSora_generation",
    # "videogenhub_T2VTurbo_generation",
    "fal_T2VTurbo_text2video",
    "fal_StableVideoDiffusion_text2video",
    "fal_CogVideoX-5B_text2video",
    "videogenhub_PyramidFlow_text2video",
]

# Optional renames from the parsed {name} component to the identifier the
# backend loader expects; load_pipeline() falls back to the name itself when
# no entry exists.
MAP_NAMES_IMAGENHUB = {}
MAP_NAMES_VIDEOGENHUB = {
    "CogVideoX-2B": "CogVideoX",
    "CogVideoX-5B": "CogVideoX5B",
}

# Presumably models excluded from the "museum" feature; consumed elsewhere —
# confirm semantics with callers. Currently empty.
MUSEUM_UNSUPPORTED_MODELS = []

# NOTE(review): "videogenhub_T2VTurbo_generation" is commented out of
# VIDEO_GENERATION_MODELS above, so it is not part of ALL_MODELS; confirm
# whether "fal_T2VTurbo_text2video" was intended here.
DESIRED_APPEAR_MODEL = [
    "videogenhub_T2VTurbo_generation",
    "fal_StableVideoDiffusion_text2video",
]
# Every advertised model must have metadata registered in model_info; fail
# fast at import time, listing all offenders in declaration order.
ALL_MODELS = [*IMAGE_GENERATION_MODELS, *IMAGE_EDITION_MODELS, *VIDEO_GENERATION_MODELS]
missing_models = [candidate for candidate in ALL_MODELS if candidate not in model_info]
if missing_models:
    raise ValueError(f"Missing models in model_info: {missing_models}")
def load_pipeline(model_name):
    """
    Load a model pipeline based on the model name.

    Args:
        model_name (str): The name of the model to load, of the form
            ``{source}_{name}_{type}``:

            - ``source`` selects the backend loader: ``imagenhub``,
              ``playground``, ``fal`` or ``videogenhub``. It must not
              contain ``_``.
            - ``name`` is the model name handed to the backend loader. It may
              itself contain ``_``: everything between the first and the last
              underscore is treated as the name.
            - ``type`` is the task suffix (e.g. ``generation``, ``edition``,
              ``text2image``, ``text2video``). It must not contain ``_``.

    Returns:
        The pipeline object returned by the matching backend loader.

    Raises:
        ValueError: If ``model_name`` does not consist of three non-empty
            ``_``-separated parts, or if ``source`` is not supported.
    """
    # Split only on the first and the last underscore so that the middle
    # name component may contain underscores (previously a plain split("_")
    # crashed on such names with an unpacking error).
    model_source, _, remainder = model_name.partition("_")
    base_name, _, model_type = remainder.rpartition("_")
    if not (model_source and base_name and model_type):
        raise ValueError(
            f"Invalid model name {model_name!r}: expected '{{source}}_{{name}}_{{type}}'"
        )
    print("model_source, model_name, model_type =", model_source, base_name, model_type)

    if model_source == "imagenhub":
        # Use the registered rename if any, otherwise the parsed name as-is.
        backend_name = MAP_NAMES_IMAGENHUB.get(base_name, base_name)
        pipe = load_imagenhub_model(backend_name, model_type)
    elif model_source == "playground":
        pipe = load_playground_model(base_name)
    elif model_source == "fal":
        pipe = load_fal_model(base_name, model_type)
    elif model_source == "videogenhub":
        # Use the registered rename if any, otherwise the parsed name as-is.
        backend_name = MAP_NAMES_VIDEOGENHUB.get(base_name, base_name)
        pipe = load_videogenhub_model(backend_name)
    else:
        raise ValueError(f"Model source {model_source} not supported")
    return pipe