# Spaces: Running on Zero
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS | |
def pre_download_all_models(include_video=True):
    """
    Pre-download all models to avoid download delay during the first user request.

    Runs the image generation, image edition and (optionally) video generation
    pre-download helpers in that order. Each helper returns the model strings
    it could not parse or load.

    Args:
        include_video: If False, video generation models are not downloaded.

    Returns:
        dict: Maps "Image Generation", "Image Edition" and "Video Generation"
        to the list of model strings that failed. The "Video Generation" entry
        is None when ``include_video`` is False, meaning the category was
        skipped rather than attempted.
    """
    errors = {
        "Image Generation": pre_download_image_models_gen(),
        "Image Edition": pre_download_image_models_edit(),
        # None (not a list) marks the category as skipped, so it cannot be
        # mistaken for a model named "None" that failed to download.
        "Video Generation": pre_download_video_models_gen() if include_video else None,
    }
    # Only claim full success when no category reported a failed model.
    if any(errors.values()):
        print("Model pre-download finished with errors.")
    else:
        print("All models downloaded.")
    print(
        "Models that encountered download error:",
        "Image Generation:", errors["Image Generation"],
        "Image Edition:", errors["Image Edition"],
        "Video Generation:", errors["Video Generation"],
    )
    return errors
def pre_download_image_models_gen():
    """
    Pre-download image generation models to avoid download delay during the first user request.

    Iterates over IMAGE_GENERATION_MODELS. Every entry whose library prefix is
    "imagenhub" is instantiated with ``imagen_hub.get_model``, which forces its
    weight files to be downloaded. The loaded instance is discarded right away.

    Entries are expected to look like ``"<lib>_<name>_<type>"``. Entries for
    other libraries are skipped. An "imagenhub" entry that does not split into
    exactly three ``_``-separated parts, or whose model fails to load, is
    reported and recorded. Either failure leaves the remaining downloads
    running.

    Returns:
        list[str]: The model strings that could not be parsed or loaded.
    """
    import imagen_hub
    errored_models = []
    for model_string in IMAGE_GENERATION_MODELS:
        parts = model_string.split("_")
        if parts[0] != "imagenhub":
            continue  # Served by another backend; nothing to pre-download here.
        if len(parts) != 3:
            # Checked explicitly: a bare 3-way unpack would raise ValueError
            # outside the try below and abort the whole loop.
            print(f"Failed to parse model string {model_string!r}: expected '<lib>_<name>_<type>'")
            errored_models.append(model_string)
            continue
        model_name = parts[1]
        try:
            print("Loading image generation model:", model_name)
            temp_model = imagen_hub.get_model(model_name)  # Forcing model to download weight files
        except Exception as e:  # Best-effort: any load failure is recorded, not fatal.
            print(f"Failed to load model {model_name} \n {e}")
            errored_models.append(model_string)
            continue
        del temp_model  # Only the downloaded files are wanted, not the loaded instance.
    return errored_models
def pre_download_image_models_edit():
    """
    Pre-download image edition models to avoid download delay during the first user request.

    Iterates over IMAGE_EDITION_MODELS. Every entry whose library prefix is
    "imagenhub" is instantiated with ``imagen_hub.get_model``, which forces its
    weight files to be downloaded. The loaded instance is discarded right away.

    Entries are expected to look like ``"<lib>_<name>_<type>"``. Entries for
    other libraries are skipped. An "imagenhub" entry that does not split into
    exactly three ``_``-separated parts, or whose model fails to load, is
    reported and recorded. Either failure leaves the remaining downloads
    running.

    Returns:
        list[str]: The model strings that could not be parsed or loaded.
    """
    import imagen_hub
    errored_models = []
    for model_string in IMAGE_EDITION_MODELS:
        parts = model_string.split("_")
        if parts[0] != "imagenhub":
            continue  # Served by another backend; nothing to pre-download here.
        if len(parts) != 3:
            # Checked explicitly: a bare 3-way unpack would raise ValueError
            # outside the try below and abort the whole loop.
            print(f"Failed to parse model string {model_string!r}: expected '<lib>_<name>_<type>'")
            errored_models.append(model_string)
            continue
        model_name = parts[1]
        try:
            print("Loading image edition model:", model_name)
            temp_model = imagen_hub.get_model(model_name)  # Forcing model to download weight files
        except Exception as e:  # Best-effort: any load failure is recorded, not fatal.
            print(f"Failed to load model {model_name} \n {e}")
            errored_models.append(model_string)
            continue
        del temp_model  # Only the downloaded files are wanted, not the loaded instance.
    return errored_models
def pre_download_video_models_gen():
    """
    Pre-download video generation models to avoid download delay during the first user request.

    Iterates over VIDEO_GENERATION_MODELS. Every entry whose library prefix is
    "videogenhub" is instantiated with ``videogen_hub.get_model``, which forces
    its weight files to be downloaded. The loaded instance is discarded right
    away.

    Entries are expected to look like ``"<lib>_<name>_<type>"``. Entries for
    other libraries are skipped. A "videogenhub" entry that does not split
    into exactly three ``_``-separated parts, or whose model fails to load, is
    reported and recorded. Either failure leaves the remaining downloads
    running.

    Returns:
        list[str]: The model strings that could not be parsed or loaded.
    """
    import videogen_hub
    errored_models = []
    for model_string in VIDEO_GENERATION_MODELS:
        parts = model_string.split("_")
        if parts[0] != "videogenhub":
            continue  # Served by another backend; nothing to pre-download here.
        if len(parts) != 3:
            # Checked explicitly: a bare 3-way unpack would raise ValueError
            # outside the try below and abort the whole loop.
            print(f"Failed to parse model string {model_string!r}: expected '<lib>_<name>_<type>'")
            errored_models.append(model_string)
            continue
        model_name = parts[1]
        try:
            print("Loading video generation model:", model_name)
            temp_model = videogen_hub.get_model(model_name)  # Forcing model to download weight files
        except Exception as e:  # Best-effort: any load failure is recorded, not fatal.
            print(f"Failed to load model {model_name} \n {e}")
            errored_models.append(model_string)
            continue
        del temp_model  # Only the downloaded files are wanted, not the loaded instance.
    return errored_models