Spaces:
Running
on
Zero
Running
on
Zero
adamelliotfields
commited on
Commit
•
13b498b
1
Parent(s):
069fc81
Add progress_bar context manager
Browse files- lib/__init__.py +3 -2
- lib/inference.py +24 -16
- lib/loader.py +33 -17
- lib/logger.py +0 -18
- lib/utils.py +11 -0
lib/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from .config import Config
|
2 |
from .inference import generate
|
3 |
from .loader import Loader
|
4 |
-
from .logger import Logger
|
5 |
from .upscaler import RealESRGAN
|
6 |
from .utils import (
|
7 |
async_call,
|
@@ -10,6 +10,7 @@ from .utils import (
|
|
10 |
download_repo_files,
|
11 |
enable_progress_bars,
|
12 |
load_json,
|
|
|
13 |
read_file,
|
14 |
timer,
|
15 |
)
|
@@ -26,7 +27,7 @@ __all__ = [
|
|
26 |
"enable_progress_bars",
|
27 |
"generate",
|
28 |
"load_json",
|
29 |
-
"
|
30 |
"read_file",
|
31 |
"timer",
|
32 |
]
|
|
|
1 |
from .config import Config
|
2 |
from .inference import generate
|
3 |
from .loader import Loader
|
4 |
+
from .logger import Logger
|
5 |
from .upscaler import RealESRGAN
|
6 |
from .utils import (
|
7 |
async_call,
|
|
|
10 |
download_repo_files,
|
11 |
enable_progress_bars,
|
12 |
load_json,
|
13 |
+
progress_bar,
|
14 |
read_file,
|
15 |
timer,
|
16 |
)
|
|
|
27 |
"enable_progress_bars",
|
28 |
"generate",
|
29 |
"load_json",
|
30 |
+
"progress_bar",
|
31 |
"read_file",
|
32 |
"timer",
|
33 |
]
|
lib/inference.py
CHANGED
@@ -16,7 +16,7 @@ from spaces import GPU
|
|
16 |
from .config import Config
|
17 |
from .loader import Loader
|
18 |
from .logger import Logger
|
19 |
-
from .utils import load_json
|
20 |
|
21 |
|
22 |
def parse_prompt_with_arrays(prompt: str) -> list[str]:
|
@@ -193,20 +193,26 @@ def generate(
|
|
193 |
weights = []
|
194 |
loras_and_weights = [(lora_1, lora_1_weight), (lora_2, lora_2_weight)]
|
195 |
loras_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "loras"))
|
196 |
-
for lora,
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
# unload after generating or if there was an error
|
212 |
try:
|
@@ -294,7 +300,9 @@ def generate(
|
|
294 |
try:
|
295 |
image = pipe(**kwargs).images[0]
|
296 |
if scale > 1:
|
297 |
-
|
|
|
|
|
298 |
images.append((image, str(current_seed)))
|
299 |
current_seed += 1
|
300 |
except Exception as e:
|
|
|
16 |
from .config import Config
|
17 |
from .loader import Loader
|
18 |
from .logger import Logger
|
19 |
+
from .utils import load_json, progress_bar, timer
|
20 |
|
21 |
|
22 |
def parse_prompt_with_arrays(prompt: str) -> list[str]:
|
|
|
193 |
weights = []
|
194 |
loras_and_weights = [(lora_1, lora_1_weight), (lora_2, lora_2_weight)]
|
195 |
loras_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "loras"))
|
196 |
+
total_loras = sum(1 for lora, _ in loras_and_weights if lora and lora.lower() != "none")
|
197 |
+
desc_loras = "Loading LoRAs"
|
198 |
+
if total_loras > 0:
|
199 |
+
with timer(f"Loading {total_loras} LoRA{'s' if total_loras > 1 else ''}"):
|
200 |
+
progress((0, total_loras), desc=desc_loras)
|
201 |
+
for i, (lora, weight) in enumerate(loras_and_weights):
|
202 |
+
if lora and lora.lower() != "none" and lora not in loras:
|
203 |
+
config = Config.CIVIT_LORAS.get(lora)
|
204 |
+
if config:
|
205 |
+
try:
|
206 |
+
pipe.load_lora_weights(
|
207 |
+
loras_dir,
|
208 |
+
adapter_name=lora,
|
209 |
+
weight_name=f"{lora}.{config['model_version_id']}.safetensors",
|
210 |
+
)
|
211 |
+
weights.append(weight)
|
212 |
+
loras.append(lora)
|
213 |
+
progress((i + 1, total_loras), desc=desc_loras)
|
214 |
+
except Exception:
|
215 |
+
raise Error(f"Error loading {config['name']} LoRA")
|
216 |
|
217 |
# unload after generating or if there was an error
|
218 |
try:
|
|
|
300 |
try:
|
301 |
image = pipe(**kwargs).images[0]
|
302 |
if scale > 1:
|
303 |
+
msg = f"Upscaling {scale}x"
|
304 |
+
with timer(msg, logger=log.info), progress_bar(100, desc=msg, progress=progress):
|
305 |
+
image = upscaler.predict(image)
|
306 |
images.append((image, str(current_seed)))
|
307 |
current_seed += 1
|
308 |
except Exception as e:
|
lib/loader.py
CHANGED
@@ -9,7 +9,7 @@ from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttn
|
|
9 |
from .config import Config
|
10 |
from .logger import Logger
|
11 |
from .upscaler import RealESRGAN
|
12 |
-
from .utils import timer
|
13 |
|
14 |
|
15 |
class Loader:
|
@@ -27,6 +27,20 @@ class Loader:
|
|
27 |
cls._instance.log = Logger("Loader")
|
28 |
return cls._instance
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def _should_unload_upscaler(self, scale=1):
|
31 |
if self.upscaler is not None and self.upscaler.scale != scale:
|
32 |
return True
|
@@ -119,12 +133,15 @@ class Loader:
|
|
119 |
setattr(self, component, None)
|
120 |
gc.collect()
|
121 |
|
122 |
-
def _load_upscaler(self, scale=1):
|
123 |
if self.upscaler is None and scale > 1:
|
124 |
try:
|
125 |
-
|
|
|
|
|
126 |
self.upscaler = RealESRGAN(scale, device=self.pipe.device)
|
127 |
self.upscaler.load_weights()
|
|
|
128 |
except Exception as e:
|
129 |
self.log.error(f"Error loading {scale}x upscaler: {e}")
|
130 |
self.upscaler = None
|
@@ -152,9 +169,10 @@ class Loader:
|
|
152 |
self.log.info("Enabling FreeU")
|
153 |
self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
|
154 |
|
155 |
-
def _load_ip_adapter(self, ip_adapter=""):
|
156 |
if not self.ip_adapter and ip_adapter:
|
157 |
-
|
|
|
158 |
self.pipe.load_ip_adapter(
|
159 |
"h94/IP-Adapter",
|
160 |
subfolder="models",
|
@@ -194,22 +212,20 @@ class Loader:
|
|
194 |
if self.pipe is not None:
|
195 |
self.pipe.set_progress_bar_config(disable=progress is not None)
|
196 |
|
197 |
-
def _load_vae(self, taesd=False, model=""):
|
198 |
-
vae_type = type(self.pipe.vae)
|
199 |
-
is_kl = issubclass(vae_type, AutoencoderKL)
|
200 |
-
is_tiny = issubclass(vae_type, AutoencoderTiny)
|
201 |
-
|
202 |
# by default all models use KL
|
203 |
-
if
|
204 |
-
|
|
|
205 |
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
206 |
pretrained_model_name_or_path="madebyollin/taesd",
|
207 |
torch_dtype=self.pipe.dtype,
|
208 |
).to(self.pipe.device)
|
209 |
return
|
210 |
|
211 |
-
if
|
212 |
-
|
|
|
213 |
if model.lower() in Config.MODEL_CHECKPOINTS.keys():
|
214 |
self.pipe.vae = AutoencoderKL.from_single_file(
|
215 |
f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
|
@@ -305,8 +321,8 @@ class Loader:
|
|
305 |
if not same_scheduler or not same_karras:
|
306 |
self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
|
307 |
|
308 |
-
self._load_vae(taesd, model)
|
309 |
self._load_freeu(freeu)
|
310 |
self._load_deepcache(deepcache)
|
311 |
-
self._load_ip_adapter(ip_adapter)
|
312 |
-
self._load_upscaler(scale)
|
|
|
9 |
from .config import Config
|
10 |
from .logger import Logger
|
11 |
from .upscaler import RealESRGAN
|
12 |
+
from .utils import progress_bar, timer
|
13 |
|
14 |
|
15 |
class Loader:
|
|
|
27 |
cls._instance.log = Logger("Loader")
|
28 |
return cls._instance
|
29 |
|
30 |
+
@property
|
31 |
+
def _is_kl_vae(self):
|
32 |
+
if self.pipe is not None:
|
33 |
+
vae_type = type(self.pipe.vae)
|
34 |
+
return issubclass(vae_type, AutoencoderKL)
|
35 |
+
return False
|
36 |
+
|
37 |
+
@property
|
38 |
+
def _is_tiny_vae(self):
|
39 |
+
if self.pipe is not None:
|
40 |
+
vae_type = type(self.pipe.vae)
|
41 |
+
return issubclass(vae_type, AutoencoderTiny)
|
42 |
+
return False
|
43 |
+
|
44 |
def _should_unload_upscaler(self, scale=1):
|
45 |
if self.upscaler is not None and self.upscaler.scale != scale:
|
46 |
return True
|
|
|
133 |
setattr(self, component, None)
|
134 |
gc.collect()
|
135 |
|
136 |
+
def _load_upscaler(self, scale=1, progress=None):
|
137 |
if self.upscaler is None and scale > 1:
|
138 |
try:
|
139 |
+
msg = f"Loading {scale}x upscaler"
|
140 |
+
# fmt: off
|
141 |
+
with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
|
142 |
self.upscaler = RealESRGAN(scale, device=self.pipe.device)
|
143 |
self.upscaler.load_weights()
|
144 |
+
# fmt: on
|
145 |
except Exception as e:
|
146 |
self.log.error(f"Error loading {scale}x upscaler: {e}")
|
147 |
self.upscaler = None
|
|
|
169 |
self.log.info("Enabling FreeU")
|
170 |
self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)
|
171 |
|
172 |
+
def _load_ip_adapter(self, ip_adapter="", progress=None):
|
173 |
if not self.ip_adapter and ip_adapter:
|
174 |
+
msg = "Loading IP-Adapter"
|
175 |
+
with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
|
176 |
self.pipe.load_ip_adapter(
|
177 |
"h94/IP-Adapter",
|
178 |
subfolder="models",
|
|
|
212 |
if self.pipe is not None:
|
213 |
self.pipe.set_progress_bar_config(disable=progress is not None)
|
214 |
|
215 |
+
def _load_vae(self, taesd=False, model="", progress=None):
|
|
|
|
|
|
|
|
|
216 |
# by default all models use KL
|
217 |
+
if self._is_kl_vae and taesd:
|
218 |
+
msg = "Loading Tiny VAE"
|
219 |
+
with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
|
220 |
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
221 |
pretrained_model_name_or_path="madebyollin/taesd",
|
222 |
torch_dtype=self.pipe.dtype,
|
223 |
).to(self.pipe.device)
|
224 |
return
|
225 |
|
226 |
+
if self._is_tiny_vae and not taesd:
|
227 |
+
msg = "Loading KL VAE"
|
228 |
+
with timer(msg, logger=self.log.info), progress_bar(100, desc=msg, progress=progress):
|
229 |
if model.lower() in Config.MODEL_CHECKPOINTS.keys():
|
230 |
self.pipe.vae = AutoencoderKL.from_single_file(
|
231 |
f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
|
|
|
321 |
if not same_scheduler or not same_karras:
|
322 |
self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
|
323 |
|
324 |
+
self._load_vae(taesd, model, progress)
|
325 |
self._load_freeu(freeu)
|
326 |
self._load_deepcache(deepcache)
|
327 |
+
self._load_ip_adapter(ip_adapter, progress)
|
328 |
+
self._load_upscaler(scale, progress)
|
lib/logger.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import logging
|
2 |
-
from functools import wraps
|
3 |
from threading import Lock
|
4 |
|
5 |
|
@@ -54,20 +53,3 @@ class Logger:
|
|
54 |
|
55 |
def critical(self, message, **kwargs):
|
56 |
self._log(logging.CRITICAL, message, **kwargs)
|
57 |
-
|
58 |
-
|
59 |
-
# decorator for logging function calls
|
60 |
-
def log_fn(name=None):
|
61 |
-
def decorator(fn):
|
62 |
-
@wraps(fn)
|
63 |
-
def wrapper(*args, **kwargs):
|
64 |
-
log = Logger(name or fn.__name__)
|
65 |
-
log.info("begin")
|
66 |
-
result = fn(*args, **kwargs)
|
67 |
-
log.info("end")
|
68 |
-
|
69 |
-
return result
|
70 |
-
|
71 |
-
return wrapper
|
72 |
-
|
73 |
-
return decorator
|
|
|
1 |
import logging
|
|
|
2 |
from threading import Lock
|
3 |
|
4 |
|
|
|
53 |
|
54 |
def critical(self, message, **kwargs):
|
55 |
self._log(logging.CRITICAL, message, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lib/utils.py
CHANGED
@@ -35,6 +35,17 @@ def timer(message="Operation", logger=print):
|
|
35 |
logger(f"{message} took {end - start:.2f}s")
|
36 |
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
@functools.lru_cache()
|
39 |
def load_json(path: str) -> dict:
|
40 |
with open(path, "r", encoding="utf-8") as file:
|
|
|
35 |
logger(f"{message} took {end - start:.2f}s")
|
36 |
|
37 |
|
38 |
+
@contextmanager
|
39 |
+
def progress_bar(total, desc="Loading", progress=None):
|
40 |
+
if progress is None:
|
41 |
+
yield
|
42 |
+
try:
|
43 |
+
progress((0, total), desc=desc)
|
44 |
+
yield
|
45 |
+
finally:
|
46 |
+
progress((total, total), desc=desc)
|
47 |
+
|
48 |
+
|
49 |
@functools.lru_cache()
|
50 |
def load_json(path: str) -> dict:
|
51 |
with open(path, "r", encoding="utf-8") as file:
|