File size: 4,591 Bytes
bd0bce1
 
944dd2b
26dad4e
0b4b1e4
bd0bce1
94bd22c
9d5a8fc
 
 
77d3539
2707b9d
09a289b
4d32483
bd0bce1
222ecc0
 
172a089
9d5a8fc
8c9a9e8
3f5b1a3
1599f4c
3f5b1a3
e548ada
0517dbe
3f5b1a3
 
 
0517dbe
bf853bd
 
e548ada
aab9d8e
3840d4b
e548ada
3840d4b
9bf990d
bd0bce1
0b4b1e4
 
 
 
 
 
bd0bce1
 
 
 
 
 
 
 
 
f969b11
 
89adc00
bd0bce1
aab9d8e
 
 
bd0bce1
 
 
513e628
944dd2b
26dad4e
aab9d8e
 
 
26dad4e
bd0bce1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from .imagenhub_models import load_imagenhub_model
from .playground_api import load_playground_model
from .fal_api_models import load_fal_model
from .videogenhub_models import load_videogenhub_model
from ..model_registry import model_info


# Arena model registry. Every identifier follows the "{source}_{name}_{type}"
# convention that load_pipeline splits apart. Disabled entries are kept as
# comments so they can be re-enabled easily.
# (An earlier image lineup served LCM(v1.5/XL), SDXLTurbo, SDXL, PixArtSigma,
# SDXLLightning and StableCascade through fal, plus PlayGroundV2 / V2.5
# through playground.)
IMAGE_GENERATION_MODELS = [
    "imagenhub_SDXLTurbo_generation",
    "imagenhub_SDXL_generation",
    "imagenhub_PixArtAlpha_generation",
    "imagenhub_PixArtSigma_generation",
    "imagenhub_OpenJourney_generation",
    "imagenhub_SDXLLightning_generation",
    "imagenhub_StableCascade_generation",
    "imagenhub_HunyuanDiT_generation",
    "playground_PlayGroundV2.5_generation",
    "imagenhub_Kolors_generation",
    "imagenhub_SD3_generation",
    "fal_AuraFlow_text2image",
    "fal_FLUX1schnell_text2image",
    "fal_FLUX1dev_text2image",
    # disabled: "playground_PlayGroundV2_generation",
]

IMAGE_EDITION_MODELS = [
    "imagenhub_CycleDiffusion_edition",
    "imagenhub_Pix2PixZero_edition",
    "imagenhub_Prompt2prompt_edition",
    "imagenhub_SDEdit_edition",
    "imagenhub_InstructPix2Pix_edition",
    "imagenhub_MagicBrush_edition",
    "imagenhub_PNP_edition",
    "imagenhub_InfEdit_edition",
    "imagenhub_CosXLEdit_edition",
    "imagenhub_UltraEdit_edition",
]

VIDEO_GENERATION_MODELS = [
    "fal_AnimateDiff_text2video",
    "fal_AnimateDiffTurbo_text2video",
    # disabled: "videogenhub_LaVie_generation",
    "videogenhub_VideoCrafter2_generation",
    # disabled: "videogenhub_ModelScope_generation",
    "videogenhub_CogVideoX-2B_generation",
    "videogenhub_OpenSora12_generation",
    # disabled: "videogenhub_OpenSora_generation",
    # disabled: "videogenhub_T2VTurbo_generation",
    "fal_T2VTurbo_text2video",
    "fal_StableVideoDiffusion_text2video",
    "fal_CogVideoX-5B_text2video",
    "videogenhub_PyramidFlow_text2video",
]

# Per-source renames applied by load_pipeline before calling the loader.
# Each key is the {name} segment of an identifier; the value is the name
# handed to the loader. Names absent from a map are passed through unchanged.
MAP_NAMES_IMAGENHUB = {}
MAP_NAMES_VIDEOGENHUB = {
    "CogVideoX-2B": "CogVideoX",
    "CogVideoX-5B": "CogVideoX5B",
}

MUSEUM_UNSUPPORTED_MODELS = []
# NOTE(review): "videogenhub_T2VTurbo_generation" is commented out of
# VIDEO_GENERATION_MODELS while "fal_T2VTurbo_text2video" is enabled;
# confirm whether this entry should reference the fal variant instead.
DESIRED_APPEAR_MODEL = [
    "videogenhub_T2VTurbo_generation",
    "fal_StableVideoDiffusion_text2video",
]

ALL_MODELS = [
    *IMAGE_GENERATION_MODELS,
    *IMAGE_EDITION_MODELS,
    *VIDEO_GENERATION_MODELS,
]

# Fail fast at import time if any enabled model has no model_info entry.
missing_models = [candidate for candidate in ALL_MODELS if candidate not in model_info]
if missing_models:
    raise ValueError(f"Missing models in model_info: {missing_models}")

def load_pipeline(model_name):
    """
    Load a model pipeline based on the model name.

    Args:
        model_name (str): Full model identifier of the form
            ``{source}_{name}_{type}``, e.g. ``imagenhub_SDXL_generation``
            or ``fal_FLUX1dev_text2image``.
            - source: one of ``imagenhub``, ``playground``, ``fal`` or
              ``videogenhub``; selects which loader is used.
            - name: the model name passed to the loader (after an optional
              rename for imagenhub / videogenhub). It may itself contain
              underscores, because only the first and the last underscore
              of the identifier are treated as separators.
            - type: the task type (e.g. ``generation``, ``edition``,
              ``text2image``, ``text2video``); forwarded to the imagenhub
              and fal loaders only.

    Returns:
        Whatever the selected source-specific loader returns.

    Raises:
        ValueError: if ``model_name`` contains fewer than two underscores,
            or if its source is not supported.
    """
    full_name = model_name
    # Split on the first underscore (source) and the last underscore (type)
    # only, so the middle segment may contain underscores of its own.
    # Identifiers with exactly two underscores parse exactly as before.
    try:
        model_source, remainder = full_name.split("_", 1)
        model_name, model_type = remainder.rsplit("_", 1)
    except ValueError:
        raise ValueError(
            f"Model name {full_name!r} is not of the form {{source}}_{{name}}_{{type}}"
        ) from None
    print("model_source, model_name, model_type =", model_source, model_name, model_type)
    if model_source == "imagenhub":
        # Fall back to the parsed name when no rename is registered.
        model_name = MAP_NAMES_IMAGENHUB.get(model_name, model_name)
        pipe = load_imagenhub_model(model_name, model_type)
    elif model_source == "playground":
        pipe = load_playground_model(model_name)
    elif model_source == "fal":
        pipe = load_fal_model(model_name, model_type)
    elif model_source == "videogenhub":
        # Fall back to the parsed name when no rename is registered.
        model_name = MAP_NAMES_VIDEOGENHUB.get(model_name, model_name)
        pipe = load_videogenhub_model(model_name)
    else:
        raise ValueError(f"Model source {model_source} not supported")
    return pipe