# GenAI-Arena / model/pre_download.py
# Last change by vinesmsuic, commit b3b25b3 ("catch error").
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
def pre_download_all_models(include_video=True):
    """
    Pre-download all models to avoid download delay during the first user request.

    Each per-task pre-download function records failures and keeps going.
    A broken model therefore does not stop the remaining downloads; failures
    are reported in a summary at the end.

    Args:
        include_video: If False, video generation models are skipped entirely
            and the video pre-download function is never called.

    Returns:
        dict[str, list[str]]: maps each task label ("Image Generation",
        "Image Edition", "Video Generation") to the model strings that failed
        to load. "Video Generation" maps to an empty list when it was skipped.
        (Previously the function returned None, so callers that ignore the
        result are unaffected.)
    """
    # Dict displays evaluate left to right, so downloads run in the same order
    # as before: image generation, image edition, then (optionally) video.
    dl_errors = {
        "Image Generation": pre_download_image_models_gen(),
        "Image Edition": pre_download_image_models_edit(),
        "Video Generation": pre_download_video_models_gen() if include_video else [],
    }

    # Only claim full success when nothing failed. Previously this line was
    # printed unconditionally, even when some downloads had errored.
    if any(dl_errors.values()):
        print("Model pre-download finished, but some models failed.")
    else:
        print("All models downloaded.")

    print(
        "Models that encountered download error:",
        "Image Generation:", dl_errors["Image Generation"],
        "Image Edition:", dl_errors["Image Edition"],
        # Report a skipped task explicitly instead of the old ["None"]
        # sentinel, which could be mistaken for a model name.
        "Video Generation:", dl_errors["Video Generation"] if include_video else "skipped",
    )
    return dl_errors
def _pre_download_models(model_strings, expected_lib, hub, task_label):
    """
    Instantiate each model of one library so its weight files get downloaded.

    Best-effort: any failure while parsing or loading an entry is printed and
    recorded, and processing continues with the next entry.

    Args:
        model_strings: Iterable of "<lib>_<name>_<type>" identifiers.
        expected_lib: Only entries whose "<lib>" prefix equals this string are
            loaded; all other entries are ignored.
        hub: Object exposing ``get_model(name)`` (e.g. the ``imagen_hub`` or
            ``videogen_hub`` module).
        task_label: Task name used in the progress message,
            e.g. "image generation".

    Returns:
        list[str]: The model strings (as given) that failed to parse or load.
    """
    errored_models = []
    for model_string in model_strings:
        # Filter on the library prefix before the strict 3-field parse.
        # A malformed entry that belongs to another library is then simply
        # skipped instead of crashing this loop.
        if model_string.split("_", 1)[0] != expected_lib:
            continue
        model_name = model_string  # label used in the error message if parsing fails
        try:
            # The parse is inside the try: a malformed entry for this library
            # is reported as errored instead of aborting the whole pre-download.
            _, model_name, _ = model_string.split("_")
            print(f"Loading {task_label} model:", model_name)
            temp_model = hub.get_model(model_name)  # Forcing model to download weight files
            del temp_model
        except Exception as e:  # deliberate best-effort boundary: record and continue
            print(f"Failed to load model {model_name} \n {e}")
            errored_models.append(model_string)
    return errored_models


def pre_download_image_models_gen():
    """
    Pre-download image generation models to avoid download delay during the first user request.

    Returns:
        list[str]: Entries of IMAGE_GENERATION_MODELS that failed to load.
    """
    import imagen_hub  # deferred: the hub package is only loaded when this function is called

    return _pre_download_models(IMAGE_GENERATION_MODELS, "imagenhub", imagen_hub, "image generation")


def pre_download_image_models_edit():
    """
    Pre-download image editing models to avoid download delay during the first user request.

    Returns:
        list[str]: Entries of IMAGE_EDITION_MODELS that failed to load.
    """
    import imagen_hub  # deferred: the hub package is only loaded when this function is called

    return _pre_download_models(IMAGE_EDITION_MODELS, "imagenhub", imagen_hub, "image edition")


def pre_download_video_models_gen():
    """
    Pre-download video generation models to avoid download delay during the first user request.

    Returns:
        list[str]: Entries of VIDEO_GENERATION_MODELS that failed to load.
    """
    import videogen_hub  # deferred: the hub package is only loaded when this function is called

    return _pre_download_models(VIDEO_GENERATION_MODELS, "videogenhub", videogen_hub, "video generation")