from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS


def pre_download_all_models():
    """
    Pre-download all models to avoid download delay during the first user request.

    Runs the image-generation, image-edition and video-generation pre-downloads
    and prints the model strings that failed for each category.
    """
    imagen_dl_error = pre_download_image_models()
    imagedit_dl_error = pre_download_image_edition_models()
    videogen_dl_error = pre_download_video_models()
    print("All models downloaded.")
    print("Models that encountered download error:",
          "Image Generation:", imagen_dl_error,
          "Image Edition:", imagedit_dl_error,
          "Video Generation:", videogen_dl_error)


def _pre_download_models(model_strings, hub_lib, get_model, kind):
    """Instantiate each ``hub_lib`` model in ``model_strings`` to force its weight download.

    Args:
        model_strings: Model identifiers of the form ``"<lib>_<name>_<type>"``.
        hub_lib: Only entries whose ``<lib>`` part equals this value are loaded;
            entries belonging to other libraries are skipped.
        get_model: Factory called with ``<name>``; constructing the model is what
            makes the library fetch its weight files.
        kind: Human-readable category used in the progress message.

    Returns:
        List of model strings that could not be parsed or failed to load.
    """
    errored_models = []
    for model_string in model_strings:
        try:
            model_lib, model_name, _model_type = model_string.split("_")
        except ValueError:
            # Malformed identifier: record it rather than aborting the whole pre-download.
            print(f"Skipping malformed model string {model_string!r}")
            errored_models.append(model_string)
            continue
        # Parse first so model_name is bound before it is printed.
        print(f"Loading {kind} model:", model_name)
        if model_lib != hub_lib:
            continue
        try:
            temp_model = get_model(model_name)  # Forcing model to download weight files
            del temp_model
        except Exception as e:
            # Best effort: report the failure and continue with the remaining models.
            print(f"Failed to load model {model_name} \n {e}")
            errored_models.append(model_string)
    return errored_models


def pre_download_image_models():
    """
    Pre-download image generation models to avoid download delay during the first user request.

    Returns:
        List of IMAGE_GENERATION_MODELS entries that failed to parse or load.
    """
    # Imported inside the function so importing this module does not require imagen_hub.
    import imagen_hub
    return _pre_download_models(
        IMAGE_GENERATION_MODELS, "imagenhub", imagen_hub.get_model, "image generation"
    )


def pre_download_image_edition_models():
    """
    Pre-download image edition models to avoid download delay during the first user request.

    Returns:
        List of IMAGE_EDITION_MODELS entries that failed to parse or load.
    """
    # Imported inside the function so importing this module does not require imagen_hub.
    import imagen_hub
    return _pre_download_models(
        IMAGE_EDITION_MODELS, "imagenhub", imagen_hub.get_model, "image edition"
    )


def pre_download_video_models():
    """
    Pre-download video models to avoid download delay during the first user request.

    Returns:
        List of VIDEO_GENERATION_MODELS entries that failed to parse or load.
    """
    # Imported inside the function so importing this module does not require videogen_hub.
    import videogen_hub
    return _pre_download_models(
        VIDEO_GENERATION_MODELS, "videogenhub", videogen_hub.get_model, "video generation"
    )