from .models import (
    IMAGE_GENERATION_MODELS,
    IMAGE_EDITION_MODELS,
    VIDEO_GENERATION_MODELS,
)


def _pre_download_models(model_strings, hub_name, get_model, kind):
    """Load every model of one hub once and report the ones that failed.

    Each entry of *model_strings* is expected to look like
    ``"<lib>_<name>_<type>"``. Only entries whose ``<lib>`` part equals
    *hub_name* are handled; entries for other libraries are skipped.

    Args:
        model_strings: Iterable of model identifier strings.
        hub_name: Library prefix this call is responsible for
            (e.g. ``"imagenhub"``).
        get_model: Callable taking the model name and returning the loaded
            model (the hub package's ``get_model``).
        kind: Human-readable category used in the progress message
            (e.g. ``"image generation"``).

    Returns:
        List of the model strings that could not be parsed or loaded.
    """
    errored_models = []
    for model_string in model_strings:
        parts = model_string.split("_")
        if parts[0] != hub_name:
            # Belongs to another backend; nothing to pre-download here.
            continue
        if len(parts) != 3:
            # A malformed entry must not abort the remaining downloads, so it
            # is reported like any other failure instead of raising.
            print(
                f"Skipping malformed model string {model_string!r}: "
                "expected '<lib>_<name>_<type>'"
            )
            errored_models.append(model_string)
            continue

        model_name = parts[1]
        try:
            print(f"Loading {kind} model:", model_name)
            # Forcing model to download weight files; the instance itself is
            # not needed afterwards, so it is dropped right away.
            temp_model = get_model(model_name)
            del temp_model
        except Exception as exc:  # best-effort: keep going with other models
            print(f"Failed to load model {model_name} \n {exc}")
            errored_models.append(model_string)
    return errored_models


def pre_download_all_models(include_video=True):
    """Pre-download all models to avoid a delay on the first user request.

    Runs the image generation and image edition pre-downloads, optionally the
    video generation one, and prints a summary of the models that failed.

    Args:
        include_video: When false, video generation models are not
            downloaded and their error list is reported as ``["None"]``.
    """
    imagen_dl_error = pre_download_image_models_gen()
    imagedit_dl_error = pre_download_image_models_edit()
    if include_video:
        videogen_dl_error = pre_download_video_models_gen()
    else:
        # Placeholder shown in the summary when the video step is skipped.
        videogen_dl_error = ["None"]
    print("All models downloaded.")
    print(
        "Models that encountered download error:",
        "Image Generation:", imagen_dl_error,
        "Image Edition:", imagedit_dl_error,
        "Video Generation:", videogen_dl_error,
    )


def pre_download_image_models_gen():
    """Pre-download the ``imagenhub`` image generation models.

    Returns:
        List of model strings from ``IMAGE_GENERATION_MODELS`` that failed.
    """
    # Imported inside the function so the dependency is only required when
    # this function is actually called.
    import imagen_hub

    return _pre_download_models(
        IMAGE_GENERATION_MODELS, "imagenhub", imagen_hub.get_model,
        "image generation",
    )


def pre_download_image_models_edit():
    """Pre-download the ``imagenhub`` image edition models.

    Returns:
        List of model strings from ``IMAGE_EDITION_MODELS`` that failed.
    """
    # Imported inside the function so the dependency is only required when
    # this function is actually called.
    import imagen_hub

    return _pre_download_models(
        IMAGE_EDITION_MODELS, "imagenhub", imagen_hub.get_model,
        "image edition",
    )


def pre_download_video_models_gen():
    """Pre-download the ``videogenhub`` video generation models.

    Returns:
        List of model strings from ``VIDEO_GENERATION_MODELS`` that failed.
    """
    # Imported inside the function so the dependency is only required when
    # this function is actually called.
    import videogen_hub

    return _pre_download_models(
        VIDEO_GENERATION_MODELS, "videogenhub", videogen_hub.get_model,
        "video generation",
    )