adamelliotfields commited on
Commit
9e8b99d
1 Parent(s): 52bf5e0

Use safe_progress helper

Browse files
Files changed (4) hide show
  1. lib/__init__.py +2 -0
  2. lib/inference.py +15 -13
  3. lib/loader.py +27 -15
  4. lib/utils.py +5 -0
lib/__init__.py CHANGED
@@ -11,6 +11,7 @@ from .utils import (
11
  enable_progress_bars,
12
  load_json,
13
  read_file,
 
14
  timer,
15
  )
16
 
@@ -27,5 +28,6 @@ __all__ = [
27
  "generate",
28
  "load_json",
29
  "read_file",
 
30
  "timer",
31
  ]
 
11
  enable_progress_bars,
12
  load_json,
13
  read_file,
14
+ safe_progress,
15
  timer,
16
  )
17
 
 
28
  "generate",
29
  "load_json",
30
  "read_file",
31
+ "safe_progress",
32
  "timer",
33
  ]
lib/inference.py CHANGED
@@ -16,7 +16,7 @@ from spaces import GPU
16
  from .config import Config
17
  from .loader import Loader
18
  from .logger import Logger
19
- from .utils import load_json, timer
20
 
21
 
22
  def parse_prompt_with_arrays(prompt: str) -> list[str]:
@@ -128,8 +128,8 @@ def generate(
128
  log = Logger("generate")
129
  log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}")
130
 
131
- if Config.ZERO_GPU and progress is not None:
132
- progress((100, 100), desc="ZeroGPU init")
133
 
134
  if not torch.cuda.is_available():
135
  raise Error("CUDA not available")
@@ -197,7 +197,7 @@ def generate(
197
  desc_loras = "Loading LoRAs"
198
  if total_loras > 0:
199
  with timer(f"Loading {total_loras} LoRA{'s' if total_loras > 1 else ''}"):
200
- progress((0, total_loras), desc=desc_loras)
201
  for i, (lora, weight) in enumerate(loras_and_weights):
202
  if lora and lora.lower() != "none" and lora not in loras:
203
  config = Config.CIVIT_LORAS.get(lora)
@@ -210,7 +210,7 @@ def generate(
210
  )
211
  weights.append(weight)
212
  loras.append(lora)
213
- progress((i + 1, total_loras), desc=desc_loras)
214
  except Exception:
215
  raise Error(f"Error loading {config['name']} LoRA")
216
 
@@ -247,6 +247,7 @@ def generate(
247
 
248
  images = []
249
  current_seed = seed
 
250
  for i in range(num_images):
251
  try:
252
  generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
@@ -299,16 +300,8 @@ def generate(
299
 
300
  try:
301
  image = pipe(**kwargs).images[0]
302
- if scale > 1:
303
- msg = f"Upscaling {scale}x"
304
- with timer(msg, logger=log.info):
305
- progress((0, 100), desc=msg)
306
- image = upscaler.predict(image)
307
- progress((100, 100), desc=msg)
308
  images.append((image, str(current_seed)))
309
  current_seed += 1
310
- except Exception as e:
311
- raise Error(f"{e}")
312
  finally:
313
  if embeddings:
314
  pipe.unload_textual_inversion()
@@ -317,6 +310,15 @@ def generate(
317
  CURRENT_STEP = 0
318
  CURRENT_IMAGE += 1
319
 
 
 
 
 
 
 
 
 
 
320
  # cleanup
321
  loader.collect()
322
  gc.collect()
 
16
  from .config import Config
17
  from .loader import Loader
18
  from .logger import Logger
19
+ from .utils import load_json, safe_progress, timer
20
 
21
 
22
  def parse_prompt_with_arrays(prompt: str) -> list[str]:
 
128
  log = Logger("generate")
129
  log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}")
130
 
131
+ if Config.ZERO_GPU:
132
+ safe_progress(progress, 100, 100, "ZeroGPU init")
133
 
134
  if not torch.cuda.is_available():
135
  raise Error("CUDA not available")
 
197
  desc_loras = "Loading LoRAs"
198
  if total_loras > 0:
199
  with timer(f"Loading {total_loras} LoRA{'s' if total_loras > 1 else ''}"):
200
+ safe_progress(progress, 0, total_loras, desc_loras)
201
  for i, (lora, weight) in enumerate(loras_and_weights):
202
  if lora and lora.lower() != "none" and lora not in loras:
203
  config = Config.CIVIT_LORAS.get(lora)
 
210
  )
211
  weights.append(weight)
212
  loras.append(lora)
213
+ safe_progress(progress, i + 1, total_loras, desc_loras)
214
  except Exception:
215
  raise Error(f"Error loading {config['name']} LoRA")
216
 
 
247
 
248
  images = []
249
  current_seed = seed
250
+ safe_progress(progress, 0, num_images, f"Generating image 0/{num_images}")
251
  for i in range(num_images):
252
  try:
253
  generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
 
300
 
301
  try:
302
  image = pipe(**kwargs).images[0]
 
 
 
 
 
 
303
  images.append((image, str(current_seed)))
304
  current_seed += 1
 
 
305
  finally:
306
  if embeddings:
307
  pipe.unload_textual_inversion()
 
310
  CURRENT_STEP = 0
311
  CURRENT_IMAGE += 1
312
 
313
+ if scale > 1:
314
+ msg = f"Upscaling {scale}x"
315
+ with timer(msg, logger=log.info):
316
+ safe_progress(progress, 0, num_images, desc=msg)
317
+ for i, image in enumerate(images):
318
+ image = upscaler.predict(image[0])
319
+ images[i] = image
320
+ safe_progress(progress, i + 1, num_images, desc=msg)
321
+
322
  # cleanup
323
  loader.collect()
324
  gc.collect()
lib/loader.py CHANGED
@@ -9,7 +9,7 @@ from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttn
9
  from .config import Config
10
  from .logger import Logger
11
  from .upscaler import RealESRGAN
12
- from .utils import timer
13
 
14
 
15
  class Loader:
@@ -62,6 +62,11 @@ class Loader:
62
  return True
63
  return False
64
 
 
 
 
 
 
65
  def _should_unload_ip_adapter(self, model="", ip_adapter=""):
66
  # unload if model changed
67
  if self.model and self.model.lower() != model.lower():
@@ -128,7 +133,7 @@ class Loader:
128
  if self._should_unload_deepcache(deepcache): # remove deepcache first
129
  self._unload_deepcache()
130
 
131
- if self._has_freeu and not freeu:
132
  self._unload_freeu()
133
 
134
  if self._should_unload_upscaler(scale):
@@ -154,6 +159,11 @@ class Loader:
154
  return True
155
  return False
156
 
 
 
 
 
 
157
  def _should_load_deepcache(self, interval=1):
158
  has_deepcache = hasattr(self.pipe, "deepcache")
159
  if not has_deepcache and interval != 1:
@@ -176,11 +186,9 @@ class Loader:
176
  if self._should_load_upscaler(scale):
177
  try:
178
  msg = f"Loading {scale}x upscaler"
179
- # fmt: off
180
  with timer(msg, logger=self.log.info):
181
  self.upscaler = RealESRGAN(scale, device=self.pipe.device)
182
  self.upscaler.load_weights()
183
- # fmt: on
184
  except Exception as e:
185
  self.log.error(f"Error loading {scale}x upscaler: {e}")
186
  self.upscaler = None
@@ -194,7 +202,7 @@ class Loader:
194
 
195
  # https://github.com/ChenyangSi/FreeU
196
  def _load_freeu(self, freeu=False):
197
- if not self._has_freeu and freeu:
198
  self.log.info("Enabling FreeU")
199
  self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
200
 
@@ -355,30 +363,34 @@ class Loader:
355
  [
356
  self._is_kl_vae and taesd,
357
  self._is_tiny_vae and not taesd,
358
- not self._has_freeu and freeu,
359
  self._should_load_deepcache(deepcache),
360
  self._should_load_ip_adapter(ip_adapter),
361
  self._should_load_upscaler(scale),
362
  ]
363
  )
364
 
365
- msg = "Loading additional features"
366
- if self._is_kl_vae and taesd or self._is_tiny_vae and not taesd:
367
- self._load_vae(taesd, model)
368
- progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
369
- CURRENT_STEP += 1
370
  if not self._has_freeu and freeu:
371
  self._load_freeu(freeu)
372
- progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
373
  CURRENT_STEP += 1
 
374
  if self._should_load_deepcache(deepcache):
375
  self._load_deepcache(deepcache)
376
- progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
377
  CURRENT_STEP += 1
 
378
  if self._should_load_ip_adapter(ip_adapter):
379
  self._load_ip_adapter(ip_adapter)
380
- progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
381
  CURRENT_STEP += 1
 
382
  if self._should_load_upscaler(scale):
383
  self._load_upscaler(scale)
384
- progress((CURRENT_STEP, TOTAL_STEPS), desc=msg)
 
 
 
 
 
 
9
  from .config import Config
10
  from .logger import Logger
11
  from .upscaler import RealESRGAN
12
+ from .utils import safe_progress, timer
13
 
14
 
15
  class Loader:
 
62
  return True
63
  return False
64
 
65
+ def _should_unload_freeu(self, freeu=False):
66
+ if self._has_freeu and not freeu:
67
+ return True
68
+ return False
69
+
70
  def _should_unload_ip_adapter(self, model="", ip_adapter=""):
71
  # unload if model changed
72
  if self.model and self.model.lower() != model.lower():
 
133
  if self._should_unload_deepcache(deepcache): # remove deepcache first
134
  self._unload_deepcache()
135
 
136
+ if self._should_unload_freeu(freeu):
137
  self._unload_freeu()
138
 
139
  if self._should_unload_upscaler(scale):
 
159
  return True
160
  return False
161
 
162
+ def _should_load_freeu(self, freeu=False):
163
+ if not self._has_freeu and freeu:
164
+ return True
165
+ return False
166
+
167
  def _should_load_deepcache(self, interval=1):
168
  has_deepcache = hasattr(self.pipe, "deepcache")
169
  if not has_deepcache and interval != 1:
 
186
  if self._should_load_upscaler(scale):
187
  try:
188
  msg = f"Loading {scale}x upscaler"
 
189
  with timer(msg, logger=self.log.info):
190
  self.upscaler = RealESRGAN(scale, device=self.pipe.device)
191
  self.upscaler.load_weights()
 
192
  except Exception as e:
193
  self.log.error(f"Error loading {scale}x upscaler: {e}")
194
  self.upscaler = None
 
202
 
203
  # https://github.com/ChenyangSi/FreeU
204
  def _load_freeu(self, freeu=False):
205
+ if self._should_load_freeu(freeu):
206
  self.log.info("Enabling FreeU")
207
  self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
208
 
 
363
  [
364
  self._is_kl_vae and taesd,
365
  self._is_tiny_vae and not taesd,
366
+ self._should_load_freeu(freeu),
367
  self._should_load_deepcache(deepcache),
368
  self._should_load_ip_adapter(ip_adapter),
369
  self._should_load_upscaler(scale),
370
  ]
371
  )
372
 
373
+ desc = "Configuring pipeline"
 
 
 
 
374
  if not self._has_freeu and freeu:
375
  self._load_freeu(freeu)
376
+ safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
377
  CURRENT_STEP += 1
378
+
379
  if self._should_load_deepcache(deepcache):
380
  self._load_deepcache(deepcache)
381
+ safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
382
  CURRENT_STEP += 1
383
+
384
  if self._should_load_ip_adapter(ip_adapter):
385
  self._load_ip_adapter(ip_adapter)
386
+ safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
387
  CURRENT_STEP += 1
388
+
389
  if self._should_load_upscaler(scale):
390
  self._load_upscaler(scale)
391
+ safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
392
+ CURRENT_STEP += 1
393
+
394
+ if self._is_kl_vae and taesd or self._is_tiny_vae and not taesd:
395
+ self._load_vae(taesd, model)
396
+ safe_progress(progress, CURRENT_STEP, TOTAL_STEPS, desc)
lib/utils.py CHANGED
@@ -58,6 +58,11 @@ def enable_progress_bars():
58
  diffusers_logging.enable_progress_bar()
59
 
60
 
 
 
 
 
 
61
  def download_repo_files(repo_id, allow_patterns, token=None):
62
  was_disabled = are_progress_bars_disabled()
63
  enable_progress_bars()
 
58
  diffusers_logging.enable_progress_bar()
59
 
60
 
61
+ def safe_progress(progress, current=0, total=0, desc=""):
62
+ if progress is not None:
63
+ progress((current, total), desc=desc)
64
+
65
+
66
  def download_repo_files(repo_id, allow_patterns, token=None):
67
  was_disabled = are_progress_bars_disabled()
68
  enable_progress_bars()