Spaces:
Running
on
Zero
Running
on
Zero
vinesmsuic
commited on
Commit
•
4a07334
1
Parent(s):
b3212f3
Fix bug on model preloadling
Browse files- model/model_manager.py +2 -2
- model/pre_download.py +8 -7
model/model_manager.py
CHANGED
@@ -7,7 +7,7 @@ import spaces
|
|
7 |
from PIL import Image
|
8 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
|
9 |
from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum
|
10 |
-
from .pre_download import pre_download_all_models,
|
11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
12 |
import torch
|
13 |
import re
|
@@ -82,7 +82,7 @@ class ModelManager:
|
|
82 |
self.load_guard(enable_nsfw)
|
83 |
self.loaded_models = {}
|
84 |
if do_pre_download:
|
85 |
-
pre_download_all_models()
|
86 |
if do_debug_packages:
|
87 |
debug_packages()
|
88 |
|
|
|
7 |
from PIL import Image
|
8 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
|
9 |
from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum
|
10 |
+
from .pre_download import pre_download_all_models, pre_download_image_models_gen, pre_download_image_models_edit, pre_download_video_models_gen
|
11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
12 |
import torch
|
13 |
import re
|
|
|
82 |
self.load_guard(enable_nsfw)
|
83 |
self.loaded_models = {}
|
84 |
if do_pre_download:
|
85 |
+
pre_download_all_models(include_video=False)
|
86 |
if do_debug_packages:
|
87 |
debug_packages()
|
88 |
|
model/pre_download.py
CHANGED
@@ -1,16 +1,17 @@
|
|
1 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
|
2 |
|
3 |
-
def pre_download_all_models():
|
4 |
"""
|
5 |
Pre-download all models to avoid download delay during the first user request
|
6 |
"""
|
7 |
-
imagen_dl_error =
|
8 |
-
imagedit_dl_error =
|
9 |
-
|
|
|
10 |
print("All models downloaded.")
|
11 |
print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
|
12 |
|
13 |
-
def
|
14 |
"""
|
15 |
Pre-download image models to avoid download delay during the first user request
|
16 |
"""
|
@@ -33,7 +34,7 @@ def pre_download_image_models():
|
|
33 |
pass
|
34 |
return errored_models
|
35 |
|
36 |
-
def
|
37 |
"""
|
38 |
Pre-download image models to avoid download delay during the first user request
|
39 |
"""
|
@@ -56,7 +57,7 @@ def pre_download_image_models():
|
|
56 |
pass
|
57 |
return errored_models
|
58 |
|
59 |
-
def
|
60 |
"""
|
61 |
Pre-download video models to avoid download delay during the first user request
|
62 |
"""
|
|
|
1 |
from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
|
2 |
|
3 |
+
def pre_download_all_models(include_video=True):
|
4 |
"""
|
5 |
Pre-download all models to avoid download delay during the first user request
|
6 |
"""
|
7 |
+
imagen_dl_error = pre_download_image_models_gen()
|
8 |
+
imagedit_dl_error = pre_download_image_models_edit()
|
9 |
+
if include_video:
|
10 |
+
videogen_dl_error = pre_download_video_models_gen()
|
11 |
print("All models downloaded.")
|
12 |
print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
|
13 |
|
14 |
+
def pre_download_image_models_gen():
|
15 |
"""
|
16 |
Pre-download image models to avoid download delay during the first user request
|
17 |
"""
|
|
|
34 |
pass
|
35 |
return errored_models
|
36 |
|
37 |
+
def pre_download_image_models_edit():
|
38 |
"""
|
39 |
Pre-download image models to avoid download delay during the first user request
|
40 |
"""
|
|
|
57 |
pass
|
58 |
return errored_models
|
59 |
|
60 |
+
def pre_download_video_models_gen():
|
61 |
"""
|
62 |
Pre-download video models to avoid download delay during the first user request
|
63 |
"""
|