adamelliotfields committed on
Commit 13b498b
1 Parent(s): 069fc81

Add progress_bar context manager

Files changed (5)
  1. lib/__init__.py +3 -2
  2. lib/inference.py +24 -16
  3. lib/loader.py +33 -17
  4. lib/logger.py +0 -18
  5. lib/utils.py +11 -0
lib/__init__.py CHANGED
@@ -1,7 +1,7 @@
 from .config import Config
 from .inference import generate
 from .loader import Loader
-from .logger import Logger, log_fn
+from .logger import Logger
 from .upscaler import RealESRGAN
 from .utils import (
     async_call,
@@ -10,6 +10,7 @@ from .utils import (
     download_repo_files,
     enable_progress_bars,
     load_json,
+    progress_bar,
     read_file,
     timer,
 )
@@ -26,7 +27,7 @@ __all__ = [
     "enable_progress_bars",
     "generate",
     "load_json",
-    "log_fn",
+    "progress_bar",
     "read_file",
     "timer",
 ]
lib/inference.py CHANGED
@@ -16,7 +16,7 @@ from spaces import GPU
 from .config import Config
 from .loader import Loader
 from .logger import Logger
-from .utils import load_json
+from .utils import load_json, progress_bar, timer


 def parse_prompt_with_arrays(prompt: str) -> list[str]:
@@ -193,20 +193,26 @@ def generate(
     weights = []
     loras_and_weights = [(lora_1, lora_1_weight), (lora_2, lora_2_weight)]
     loras_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "loras"))
-    for lora, weight in loras_and_weights:
-        if lora and lora.lower() != "none" and lora not in loras:
-            config = Config.CIVIT_LORAS.get(lora)
-            if config:
-                try:
-                    pipe.load_lora_weights(
-                        loras_dir,
-                        adapter_name=lora,
-                        weight_name=f"{lora}.{config['model_version_id']}.safetensors",
-                    )
-                    weights.append(weight)
-                    loras.append(lora)
-                except Exception:
-                    raise Error(f"Error loading {config['name']} LoRA")
+    total_loras = sum(1 for lora, _ in loras_and_weights if lora and lora.lower() != "none")
+    desc_loras = "Loading LoRAs"
+    if total_loras > 0:
+        with timer(f"Loading {total_loras} LoRA{'s' if total_loras > 1 else ''}"):
+            progress((0, total_loras), desc=desc_loras)
+            for i, (lora, weight) in enumerate(loras_and_weights):
+                if lora and lora.lower() != "none" and lora not in loras:
+                    config = Config.CIVIT_LORAS.get(lora)
+                    if config:
+                        try:
+                            pipe.load_lora_weights(
+                                loras_dir,
+                                adapter_name=lora,
+                                weight_name=f"{lora}.{config['model_version_id']}.safetensors",
+                            )
+                            weights.append(weight)
+                            loras.append(lora)
+                            progress((i + 1, total_loras), desc=desc_loras)
+                        except Exception:
+                            raise Error(f"Error loading {config['name']} LoRA")

     # unload after generating or if there was an error
     try:
@@ -294,7 +300,9 @@ def generate(
         try:
             image = pipe(**kwargs).images[0]
             if scale > 1:
-                image = upscaler.predict(image)
+                msg = f"Upscaling {scale}x"
+                with timer(msg, logger=log.info), progress_bar(100, desc=msg, progress=progress):
+                    image = upscaler.predict(image)
             images.append((image, str(current_seed)))
             current_seed += 1
         except Exception as e:
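
The `progress` callable threaded through `generate` follows Gradio's `gr.Progress` convention: calling it with a `(completed, total)` tuple sets an absolute position on the tracked bar, and `desc=` labels it. A minimal sketch of that convention outside this repo (the `task` function and its slider input are illustrative, not part of the commit):

import time

import gradio as gr


def task(steps, progress=gr.Progress()):
    # gr.Progress accepts a (completed, total) tuple plus a desc label,
    # the same shape the LoRA loop and progress_bar() use above
    steps = int(steps)
    progress((0, steps), desc="Working")
    for i in range(steps):
        time.sleep(0.1)  # stand-in for one unit of work
        progress((i + 1, steps), desc="Working")
    return "done"


demo = gr.Interface(fn=task, inputs=gr.Slider(1, 10, step=1, value=5), outputs="text")

if __name__ == "__main__":
    demo.launch()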
lib/loader.py CHANGED
@@ -9,7 +9,7 @@ from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttnProcessor2_0
 from .config import Config
 from .logger import Logger
 from .upscaler import RealESRGAN
-from .utils import timer
+from .utils import progress_bar, timer


 class Loader:
@@ -27,6 +27,20 @@ class Loader:
             cls._instance.log = Logger("Loader")
         return cls._instance

+    @property
+    def _is_kl_vae(self):
+        if self.pipe is not None:
+            vae_type = type(self.pipe.vae)
+            return issubclass(vae_type, AutoencoderKL)
+        return False
+
+    @property
+    def _is_tiny_vae(self):
+        if self.pipe is not None:
+            vae_type = type(self.pipe.vae)
+            return issubclass(vae_type, AutoencoderTiny)
+        return False
+
     def _should_unload_upscaler(self, scale=1):
         if self.upscaler is not None and self.upscaler.scale != scale:
             return True
@@ -119,12 +133,15 @@ class Loader:
            setattr(self, component, None)
            gc.collect()

-    def _load_upscaler(self, scale=1):
+    def _load_upscaler(self, scale=1, progress=None):
         if self.upscaler is None and scale > 1:
             try:
-                with timer(f"Loading {scale}x upscaler", logger=self.log.info):
+                msg = f"Loading {scale}x upscaler"
+                # fmt: off
+                with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
                     self.upscaler = RealESRGAN(scale, device=self.pipe.device)
                     self.upscaler.load_weights()
+                # fmt: on
             except Exception as e:
                 self.log.error(f"Error loading {scale}x upscaler: {e}")
                 self.upscaler = None
@@ -152,9 +169,10 @@ class Loader:
            self.log.info("Enabling FreeU")
            self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)

-    def _load_ip_adapter(self, ip_adapter=""):
+    def _load_ip_adapter(self, ip_adapter="", progress=None):
         if not self.ip_adapter and ip_adapter:
-            with timer("Loading IP-Adapter", logger=self.log.info):
+            msg = "Loading IP-Adapter"
+            with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
                 self.pipe.load_ip_adapter(
                     "h94/IP-Adapter",
                     subfolder="models",
@@ -194,22 +212,20 @@ class Loader:
         if self.pipe is not None:
             self.pipe.set_progress_bar_config(disable=progress is not None)

-    def _load_vae(self, taesd=False, model=""):
-        vae_type = type(self.pipe.vae)
-        is_kl = issubclass(vae_type, AutoencoderKL)
-        is_tiny = issubclass(vae_type, AutoencoderTiny)
-
+    def _load_vae(self, taesd=False, model="", progress=None):
         # by default all models use KL
-        if is_kl and taesd:
-            with timer("Loading Tiny VAE", logger=self.log.info):
+        if self._is_kl_vae and taesd:
+            msg = "Loading Tiny VAE"
+            with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
                 self.pipe.vae = AutoencoderTiny.from_pretrained(
                     pretrained_model_name_or_path="madebyollin/taesd",
                     torch_dtype=self.pipe.dtype,
                 ).to(self.pipe.device)
             return

-        if is_tiny and not taesd:
-            with timer("Loading KL VAE", logger=self.log.info):
+        if self._is_tiny_vae and not taesd:
+            msg = "Loading KL VAE"
+            with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
                 if model.lower() in Config.MODEL_CHECKPOINTS.keys():
                     self.pipe.vae = AutoencoderKL.from_single_file(
                         f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
@@ -305,8 +321,8 @@ class Loader:
         if not same_scheduler or not same_karras:
             self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)

-        self._load_vae(taesd, model)
+        self._load_vae(taesd, model, progress)
         self._load_freeu(freeu)
         self._load_deepcache(deepcache)
-        self._load_ip_adapter(ip_adapter)
-        self._load_upscaler(scale)
+        self._load_ip_adapter(ip_adapter, progress)
+        self._load_upscaler(scale, progress)
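
Each loader method now stacks the two context managers in a single `with` statement so the elapsed-time log and the progress bracket share one message. The shape of the pattern, as a standalone sketch with a print-based stand-in for the Gradio callback (`fake_progress` and the empty body are illustrative):

from lib.utils import progress_bar, timer


def fake_progress(value, desc=""):
    # stand-in for gr.Progress: value arrives as a (completed, total) tuple
    done, total = value
    print(f"{desc}: {done}/{total}")


msg = "Loading 4x upscaler"
# timer logs the elapsed seconds on exit; progress_bar pins the bar at
# 0/100 on entry and 100/100 on exit, since the work inside is opaque
with timer(msg), progress_bar(100, desc=msg, progress=fake_progress):
    pass  # the opaque load would happen here

The `# fmt: off` / `# fmt: on` pair in `_load_upscaler` only stops the formatter (e.g. Black) from splitting the long `with` line; it has no runtime effect.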
lib/logger.py CHANGED
@@ -1,5 +1,4 @@
 import logging
-from functools import wraps
 from threading import Lock


@@ -54,20 +53,3 @@ class Logger:

     def critical(self, message, **kwargs):
         self._log(logging.CRITICAL, message, **kwargs)
-
-
-# decorator for logging function calls
-def log_fn(name=None):
-    def decorator(fn):
-        @wraps(fn)
-        def wrapper(*args, **kwargs):
-            log = Logger(name or fn.__name__)
-            log.info("begin")
-            result = fn(*args, **kwargs)
-            log.info("end")
-
-            return result
-
-        return wrapper
-
-    return decorator
lib/utils.py CHANGED
@@ -35,6 +35,18 @@ def timer(message="Operation", logger=print):
     logger(f"{message} took {end - start:.2f}s")


+@contextmanager
+def progress_bar(total, desc="Loading", progress=None):
+    if progress is None:
+        yield
+        return
+    try:
+        progress((0, total), desc=desc)
+        yield
+    finally:
+        progress((total, total), desc=desc)
+
+
 @functools.lru_cache()
 def load_json(path: str) -> dict:
     with open(path, "r", encoding="utf-8") as file:
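
Because the generator returns early when no callback is supplied, `progress_bar` degrades to a no-op and callers can thread an optional `progress` argument straight through. A quick check of both paths, assuming the module layout above:

from lib.utils import progress_bar

# without a callback the context manager does nothing
with progress_bar(100, desc="ignored"):
    pass

# with a callback it brackets the block at 0/100 and 100/100
updates = []
with progress_bar(100, desc="Load", progress=lambda v, desc: updates.append((v, desc))):
    pass
assert updates == [((0, 100), "Load"), ((100, 100), "Load")]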