vinesmsuic commited on
Commit
4a07334
1 Parent(s): b3212f3

Fix bug on model preloadling

Browse files
Files changed (2) hide show
  1. model/model_manager.py +2 -2
  2. model/pre_download.py +8 -7
model/model_manager.py CHANGED
@@ -7,7 +7,7 @@ import spaces
7
  from PIL import Image
8
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
9
  from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum
10
- from .pre_download import pre_download_all_models, pre_download_video_models
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
  import torch
13
  import re
@@ -82,7 +82,7 @@ class ModelManager:
82
  self.load_guard(enable_nsfw)
83
  self.loaded_models = {}
84
  if do_pre_download:
85
- pre_download_all_models()
86
  if do_debug_packages:
87
  debug_packages()
88
 
 
7
  from PIL import Image
8
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS, MUSEUM_UNSUPPORTED_MODELS, DESIRED_APPEAR_MODEL, load_pipeline
9
  from .fetch_museum_results import draw_from_imagen_museum, draw2_from_imagen_museum, draw_from_videogen_museum, draw2_from_videogen_museum
10
+ from .pre_download import pre_download_all_models, pre_download_image_models_gen, pre_download_image_models_edit, pre_download_video_models_gen
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
  import torch
13
  import re
 
82
  self.load_guard(enable_nsfw)
83
  self.loaded_models = {}
84
  if do_pre_download:
85
+ pre_download_all_models(include_video=False)
86
  if do_debug_packages:
87
  debug_packages()
88
 
model/pre_download.py CHANGED
@@ -1,16 +1,17 @@
1
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
2
 
3
- def pre_download_all_models():
4
  """
5
  Pre-download all models to avoid download delay during the first user request
6
  """
7
- imagen_dl_error = pre_download_image_models()
8
- imagedit_dl_error = pre_download_image_models()
9
- videogen_dl_error = pre_download_video_models()
 
10
  print("All models downloaded.")
11
  print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
12
 
13
- def pre_download_image_models():
14
  """
15
  Pre-download image models to avoid download delay during the first user request
16
  """
@@ -33,7 +34,7 @@ def pre_download_image_models():
33
  pass
34
  return errored_models
35
 
36
- def pre_download_image_models():
37
  """
38
  Pre-download image models to avoid download delay during the first user request
39
  """
@@ -56,7 +57,7 @@ def pre_download_image_models():
56
  pass
57
  return errored_models
58
 
59
- def pre_download_video_models():
60
  """
61
  Pre-download video models to avoid download delay during the first user request
62
  """
 
1
  from .models import IMAGE_GENERATION_MODELS, IMAGE_EDITION_MODELS, VIDEO_GENERATION_MODELS
2
 
3
+ def pre_download_all_models(include_video=True):
4
  """
5
  Pre-download all models to avoid download delay during the first user request
6
  """
7
+ imagen_dl_error = pre_download_image_models_gen()
8
+ imagedit_dl_error = pre_download_image_models_edit()
9
+ if include_video:
10
+ videogen_dl_error = pre_download_video_models_gen()
11
  print("All models downloaded.")
12
  print("Models that encountered download error:", "Image Generation:", imagen_dl_error, "Image Edition:", imagedit_dl_error, "Video Generation:", videogen_dl_error)
13
 
14
+ def pre_download_image_models_gen():
15
  """
16
  Pre-download image models to avoid download delay during the first user request
17
  """
 
34
  pass
35
  return errored_models
36
 
37
+ def pre_download_image_models_edit():
38
  """
39
  Pre-download image models to avoid download delay during the first user request
40
  """
 
57
  pass
58
  return errored_models
59
 
60
+ def pre_download_video_models_gen():
61
  """
62
  Pre-download video models to avoid download delay during the first user request
63
  """