File size: 3,103 Bytes
8fc3c7f
 
4a07334
8fc3c7f
4d8824e
8fc3c7f
4a07334
 
 
 
b3b25b3
 
8fc3c7f
4d8824e
8fc3c7f
4a07334
8fc3c7f
 
 
 
 
 
c271014
8fc3c7f
374537f
8fc3c7f
 
374537f
8fc3c7f
 
 
 
 
 
 
 
 
 
4a07334
8fc3c7f
 
 
 
 
 
c271014
8fc3c7f
 
 
 
374537f
8fc3c7f
 
 
 
 
 
 
 
 
 
4a07334
8fc3c7f
 
 
 
 
 
c271014
8fc3c7f
 
 
 
374537f
8fc3c7f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS

def pre_download_all_models(include_video=True):
    """
    Pre-download all models to avoid download delay during the first user request.

    Runs the image-generation, image-edition and (optionally) video-generation
    pre-download steps in that order, then prints a summary of the failures.

    Args:
        include_video: when False, the video generation models are not downloaded.

    Returns:
        A dict mapping each category ("Image Generation", "Image Edition",
        "Video Generation") to the list of model strings that failed to
        download. "Video Generation" maps to an empty list when skipped.
    """
    errors = {
        "Image Generation": pre_download_image_models_gen(),
        "Image Edition": pre_download_image_models_edit(),
        # Use an empty list when skipped; a placeholder string in an error list
        # would read as a model that actually failed.
        "Video Generation": pre_download_video_models_gen() if include_video else [],
    }
    # Downloads may have failed, so do not claim that everything was downloaded.
    print("Model pre-download finished.")
    print(
        "Models that encountered download error:",
        "Image Generation:", errors["Image Generation"],
        "Image Edition:", errors["Image Edition"],
        "Video Generation:", errors["Video Generation"] if include_video else "skipped",
    )
    return errors

def _pre_download_models(model_strings, lib_prefix, get_model, label):
    """
    Instantiate every model served by ``lib_prefix`` so its weight files get downloaded.

    Model strings have the form "<lib>_<name>_<type>". Entries whose library is
    not ``lib_prefix`` are skipped without further parsing. For matching entries,
    the name is everything between the first and the last underscore, so a name
    that itself contains underscores is kept intact. A matching entry with fewer
    than three parts is reported as errored instead of aborting the whole run.

    Args:
        model_strings: iterable of "<lib>_<name>_<type>" strings.
        lib_prefix: library tag to handle (e.g. "imagenhub").
        get_model: callable that loads a model given its name.
        label: category name used in the progress message.

    Returns:
        List of model strings that could not be parsed or loaded.
    """
    errored_models = []
    for model_string in model_strings:
        parts = model_string.split("_")
        if parts[0] != lib_prefix:
            # Served by another backend; nothing to pre-download here.
            continue
        if len(parts) < 3:
            print(f"Skipping malformed model string {model_string!r}")
            errored_models.append(model_string)
            continue
        model_name = "_".join(parts[1:-1])

        print(f"Loading {label} model:", model_name)
        try:
            # Instantiating the model forces its weight files to be downloaded.
            model = get_model(model_name)
        except Exception as e:  # best-effort: one bad model must not stop the rest
            print(f"Failed to load model {model_name} \n {e}")
            errored_models.append(model_string)
        else:
            # Only the download side effect is wanted; drop the instance right away.
            del model
    return errored_models


def pre_download_image_models_gen():
    """
    Pre-download image generation models to avoid download delay during the first user request.

    Returns:
        List of IMAGE_GENERATION_MODELS entries that failed to load.
    """
    import imagen_hub
    return _pre_download_models(
        IMAGE_GENERATION_MODELS, "imagenhub", imagen_hub.get_model, "image generation"
    )


def pre_download_image_models_edit():
    """
    Pre-download image edition models to avoid download delay during the first user request.

    Returns:
        List of IMAGE_EDITION_MODELS entries that failed to load.
    """
    import imagen_hub
    return _pre_download_models(
        IMAGE_EDITION_MODELS, "imagenhub", imagen_hub.get_model, "image edition"
    )


def pre_download_video_models_gen():
    """
    Pre-download video generation models to avoid download delay during the first user request.

    Returns:
        List of VIDEO_GENERATION_MODELS entries that failed to load.
    """
    import videogen_hub
    return _pre_download_models(
        VIDEO_GENERATION_MODELS, "videogenhub", videogen_hub.get_model, "video generation"
    )