Spaces:
Running
on
Zero
Running
on
Zero
adamelliotfields
commited on
Commit
•
d179c4c
1
Parent(s):
1e250ff
Terminal progress bar improvements
Browse files- app.py +10 -14
- lib/__init__.py +11 -1
- lib/config.py +40 -17
- lib/utils.py +26 -3
app.py
CHANGED
@@ -2,22 +2,18 @@ import argparse
|
|
2 |
import json
|
3 |
import os
|
4 |
import random
|
5 |
-
from warnings import filterwarnings
|
6 |
|
7 |
import gradio as gr
|
8 |
-
from diffusers.utils import logging as diffusers_logging
|
9 |
-
from transformers import logging as transformers_logging
|
10 |
|
11 |
-
from lib import
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
transformers_logging.disable_progress_bar()
|
21 |
|
22 |
# the CSS `content` attribute expects a string so we need to wrap the number in quotes
|
23 |
refresh_seed_js = """
|
@@ -567,7 +563,7 @@ if __name__ == "__main__":
|
|
567 |
parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
|
568 |
args = parser.parse_args()
|
569 |
|
570 |
-
|
571 |
for repo_id, allow_patterns in Config.HF_MODELS.items():
|
572 |
download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
|
573 |
|
|
|
2 |
import json
|
3 |
import os
|
4 |
import random
|
|
|
5 |
|
6 |
import gradio as gr
|
|
|
|
|
7 |
|
8 |
+
from lib import (
|
9 |
+
Config,
|
10 |
+
async_call,
|
11 |
+
disable_progress_bars,
|
12 |
+
download_civit_file,
|
13 |
+
download_repo_files,
|
14 |
+
generate,
|
15 |
+
read_file,
|
16 |
+
)
|
|
|
17 |
|
18 |
# the CSS `content` attribute expects a string so we need to wrap the number in quotes
|
19 |
refresh_seed_js = """
|
|
|
563 |
parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
|
564 |
args = parser.parse_args()
|
565 |
|
566 |
+
disable_progress_bars()
|
567 |
for repo_id, allow_patterns in Config.HF_MODELS.items():
|
568 |
download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
|
569 |
|
lib/__init__.py
CHANGED
@@ -3,7 +3,15 @@ from .inference import generate
|
|
3 |
from .loader import Loader
|
4 |
from .logger import Logger, log_fn
|
5 |
from .upscaler import RealESRGAN
|
6 |
-
from .utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
__all__ = [
|
9 |
"Config",
|
@@ -11,8 +19,10 @@ __all__ = [
|
|
11 |
"Logger",
|
12 |
"RealESRGAN",
|
13 |
"async_call",
|
|
|
14 |
"download_civit_file",
|
15 |
"download_repo_files",
|
|
|
16 |
"generate",
|
17 |
"load_json",
|
18 |
"log_fn",
|
|
|
3 |
from .loader import Loader
|
4 |
from .logger import Logger, log_fn
|
5 |
from .upscaler import RealESRGAN
|
6 |
+
from .utils import (
|
7 |
+
async_call,
|
8 |
+
disable_progress_bars,
|
9 |
+
download_civit_file,
|
10 |
+
download_repo_files,
|
11 |
+
enable_progress_bars,
|
12 |
+
load_json,
|
13 |
+
read_file,
|
14 |
+
)
|
15 |
|
16 |
__all__ = [
|
17 |
"Config",
|
|
|
19 |
"Logger",
|
20 |
"RealESRGAN",
|
21 |
"async_call",
|
22 |
+
"disable_progress_bars",
|
23 |
"download_civit_file",
|
24 |
"download_repo_files",
|
25 |
+
"enable_progress_bars",
|
26 |
"generate",
|
27 |
"load_json",
|
28 |
"log_fn",
|
lib/config.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import os
|
2 |
from importlib import import_module
|
|
|
3 |
from types import SimpleNamespace
|
|
|
4 |
|
5 |
from diffusers import (
|
6 |
DDIMScheduler,
|
@@ -11,33 +13,54 @@ from diffusers import (
|
|
11 |
PNDMScheduler,
|
12 |
UniPCMultistepScheduler,
|
13 |
)
|
|
|
|
|
14 |
|
15 |
from .pipelines import CustomStableDiffusionImg2ImgPipeline, CustomStableDiffusionPipeline
|
16 |
|
17 |
# improved GPU handling and progress bars; set before importing spaces
|
18 |
-
os.environ["ZEROGPU_V2"] = "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
Config = SimpleNamespace(
|
21 |
HF_TOKEN=os.environ.get("HF_TOKEN", None),
|
22 |
CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
|
23 |
ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
|
24 |
HF_MODELS={
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
"tokenizer/vocab.json",
|
35 |
-
"unet/config.json",
|
36 |
-
"unet/diffusion_pytorch_model.fp16.safetensors",
|
37 |
-
"vae/config.json",
|
38 |
-
"vae/diffusion_pytorch_model.fp16.safetensors",
|
39 |
-
"model_index.json",
|
40 |
-
],
|
41 |
},
|
42 |
CIVIT_LORAS={
|
43 |
# https://civitai.com/models/411088?modelVersionId=486099
|
|
|
1 |
import os
|
2 |
from importlib import import_module
|
3 |
+
from importlib.util import find_spec
|
4 |
from types import SimpleNamespace
|
5 |
+
from warnings import filterwarnings
|
6 |
|
7 |
from diffusers import (
|
8 |
DDIMScheduler,
|
|
|
13 |
PNDMScheduler,
|
14 |
UniPCMultistepScheduler,
|
15 |
)
|
16 |
+
from diffusers.utils import logging as diffusers_logging
|
17 |
+
from transformers import logging as transformers_logging
|
18 |
|
19 |
from .pipelines import CustomStableDiffusionImg2ImgPipeline, CustomStableDiffusionPipeline
|
20 |
|
21 |
# improved GPU handling and progress bars; set before importing spaces
|
22 |
+
os.environ["ZEROGPU_V2"] = "1"
|
23 |
+
|
24 |
+
if find_spec("hf_transfer"):
|
25 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
26 |
+
|
27 |
+
filterwarnings("ignore", category=FutureWarning, module="diffusers")
|
28 |
+
filterwarnings("ignore", category=FutureWarning, module="transformers")
|
29 |
+
|
30 |
+
diffusers_logging.set_verbosity_error()
|
31 |
+
transformers_logging.set_verbosity_error()
|
32 |
+
|
33 |
+
_sd_files = [
|
34 |
+
"feature_extractor/preprocessor_config.json",
|
35 |
+
"safety_checker/config.json",
|
36 |
+
"scheduler/scheduler_config.json",
|
37 |
+
"text_encoder/config.json",
|
38 |
+
"text_encoder/model.fp16.safetensors",
|
39 |
+
"tokenizer/merges.txt",
|
40 |
+
"tokenizer/special_tokens_map.json",
|
41 |
+
"tokenizer/tokenizer_config.json",
|
42 |
+
"tokenizer/vocab.json",
|
43 |
+
"unet/config.json",
|
44 |
+
"unet/diffusion_pytorch_model.fp16.safetensors",
|
45 |
+
"vae/config.json",
|
46 |
+
"vae/diffusion_pytorch_model.fp16.safetensors",
|
47 |
+
"model_index.json",
|
48 |
+
]
|
49 |
|
50 |
Config = SimpleNamespace(
|
51 |
HF_TOKEN=os.environ.get("HF_TOKEN", None),
|
52 |
CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
|
53 |
ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
|
54 |
HF_MODELS={
|
55 |
+
# downloaded on startup
|
56 |
+
"Lykon/dreamshaper-8": [*_sd_files],
|
57 |
+
"Comfy-Org/stable-diffusion-v1-5-archive": ["v1-5-pruned-emaonly-fp16.safetensors"],
|
58 |
+
"cyberdelia/CyberRealistic": ["CyberRealistic_V5_FP16.safetensors"],
|
59 |
+
"fluently/Fluently-v4": ["Fluently-v4.safetensors"],
|
60 |
+
"Linaqruf/anything-v3-1": ["anything-v3-2.safetensors"],
|
61 |
+
"prompthero/openjourney-v4": ["openjourney-v4.ckpt"],
|
62 |
+
"SG161222/Realistic_Vision_V5.1_noVAE": ["Realistic_Vision_V5.1_fp16-no-ema.safetensors"],
|
63 |
+
"XpucT/Deliberate": ["Deliberate_v6.safetensors"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
},
|
65 |
CIVIT_LORAS={
|
66 |
# https://civitai.com/models/411088?modelVersionId=486099
|
lib/utils.py
CHANGED
@@ -7,15 +7,22 @@ from typing import Callable, TypeVar
|
|
7 |
import anyio
|
8 |
import httpx
|
9 |
from anyio import Semaphore
|
|
|
10 |
from huggingface_hub._snapshot_download import snapshot_download
|
|
|
|
|
11 |
from typing_extensions import ParamSpec
|
12 |
|
|
|
|
|
13 |
T = TypeVar("T")
|
14 |
P = ParamSpec("P")
|
15 |
|
16 |
MAX_CONCURRENT_THREADS = 1
|
17 |
MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
|
18 |
|
|
|
|
|
19 |
|
20 |
@functools.lru_cache()
|
21 |
def load_json(path: str) -> dict:
|
@@ -29,8 +36,21 @@ def read_file(path: str) -> str:
|
|
29 |
return file.read()
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def download_repo_files(repo_id, allow_patterns, token=None):
|
33 |
-
|
|
|
|
|
34 |
repo_id=repo_id,
|
35 |
repo_type="model",
|
36 |
revision="main",
|
@@ -38,6 +58,9 @@ def download_repo_files(repo_id, allow_patterns, token=None):
|
|
38 |
allow_patterns=allow_patterns,
|
39 |
ignore_patterns=None,
|
40 |
)
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
def download_civit_file(lora_id, version_id, file_path=".", token=None):
|
@@ -62,9 +85,9 @@ def download_civit_file(lora_id, version_id, file_path=".", token=None):
|
|
62 |
with open(file, "wb") as f:
|
63 |
f.write(response.content)
|
64 |
except httpx.HTTPStatusError as e:
|
65 |
-
|
66 |
except httpx.RequestError as e:
|
67 |
-
|
68 |
|
69 |
|
70 |
# like the original but supports args and kwargs instead of a dict
|
|
|
7 |
import anyio
|
8 |
import httpx
|
9 |
from anyio import Semaphore
|
10 |
+
from diffusers.utils import logging as diffusers_logging
|
11 |
from huggingface_hub._snapshot_download import snapshot_download
|
12 |
+
from huggingface_hub.utils import are_progress_bars_disabled
|
13 |
+
from transformers import logging as transformers_logging
|
14 |
from typing_extensions import ParamSpec
|
15 |
|
16 |
+
from .logger import Logger
|
17 |
+
|
18 |
T = TypeVar("T")
|
19 |
P = ParamSpec("P")
|
20 |
|
21 |
MAX_CONCURRENT_THREADS = 1
|
22 |
MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
|
23 |
|
24 |
+
log = Logger("utils")
|
25 |
+
|
26 |
|
27 |
@functools.lru_cache()
|
28 |
def load_json(path: str) -> dict:
|
|
|
36 |
return file.read()
|
37 |
|
38 |
|
39 |
+
def disable_progress_bars():
|
40 |
+
transformers_logging.disable_progress_bar()
|
41 |
+
diffusers_logging.disable_progress_bar()
|
42 |
+
|
43 |
+
|
44 |
+
def enable_progress_bars():
|
45 |
+
# warns if `HF_HUB_DISABLE_PROGRESS_BARS` env var is not None
|
46 |
+
transformers_logging.enable_progress_bar()
|
47 |
+
diffusers_logging.enable_progress_bar()
|
48 |
+
|
49 |
+
|
50 |
def download_repo_files(repo_id, allow_patterns, token=None):
|
51 |
+
was_disabled = are_progress_bars_disabled()
|
52 |
+
enable_progress_bars()
|
53 |
+
snapshot_path = snapshot_download(
|
54 |
repo_id=repo_id,
|
55 |
repo_type="model",
|
56 |
revision="main",
|
|
|
58 |
allow_patterns=allow_patterns,
|
59 |
ignore_patterns=None,
|
60 |
)
|
61 |
+
if was_disabled:
|
62 |
+
disable_progress_bars()
|
63 |
+
return snapshot_path
|
64 |
|
65 |
|
66 |
def download_civit_file(lora_id, version_id, file_path=".", token=None):
|
|
|
85 |
with open(file, "wb") as f:
|
86 |
f.write(response.content)
|
87 |
except httpx.HTTPStatusError as e:
|
88 |
+
log.error(f"{e.response.status_code} {e.response.text}")
|
89 |
except httpx.RequestError as e:
|
90 |
+
log.error(f"RequestError: {e}")
|
91 |
|
92 |
|
93 |
# like the original but supports args and kwargs instead of a dict
|