adamelliotfields committed
Commit d179c4c
Parent: 1e250ff

Terminal progress bar improvements

Files changed (4)
  1. app.py +10 -14
  2. lib/__init__.py +11 -1
  3. lib/config.py +40 -17
  4. lib/utils.py +26 -3
app.py CHANGED
@@ -2,22 +2,18 @@ import argparse
 import json
 import os
 import random
-from warnings import filterwarnings
 
 import gradio as gr
-from diffusers.utils import logging as diffusers_logging
-from transformers import logging as transformers_logging
 
-from lib import Config, async_call, download_civit_file, download_repo_files, generate, read_file
-
-filterwarnings("ignore", category=FutureWarning, module="diffusers")
-filterwarnings("ignore", category=FutureWarning, module="transformers")
-
-diffusers_logging.set_verbosity_error()
-transformers_logging.set_verbosity_error()
-
-diffusers_logging.disable_progress_bar()
-transformers_logging.disable_progress_bar()
+from lib import (
+    Config,
+    async_call,
+    disable_progress_bars,
+    download_civit_file,
+    download_repo_files,
+    generate,
+    read_file,
+)
 
 # the CSS `content` attribute expects a string so we need to wrap the number in quotes
 refresh_seed_js = """
@@ -567,7 +563,7 @@ if __name__ == "__main__":
     parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
     args = parser.parse_args()
 
-    # download to hub cache
+    disable_progress_bars()
     for repo_id, allow_patterns in Config.HF_MODELS.items():
         download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
lib/__init__.py CHANGED
@@ -3,7 +3,15 @@ from .inference import generate
 from .loader import Loader
 from .logger import Logger, log_fn
 from .upscaler import RealESRGAN
-from .utils import async_call, download_civit_file, download_repo_files, load_json, read_file
+from .utils import (
+    async_call,
+    disable_progress_bars,
+    download_civit_file,
+    download_repo_files,
+    enable_progress_bars,
+    load_json,
+    read_file,
+)
 
 __all__ = [
     "Config",
@@ -11,8 +19,10 @@ __all__ = [
     "Logger",
     "RealESRGAN",
     "async_call",
+    "disable_progress_bars",
     "download_civit_file",
     "download_repo_files",
+    "enable_progress_bars",
     "generate",
     "load_json",
     "log_fn",
lib/config.py CHANGED
@@ -1,6 +1,8 @@
 import os
 from importlib import import_module
+from importlib.util import find_spec
 from types import SimpleNamespace
+from warnings import filterwarnings
 
 from diffusers import (
     DDIMScheduler,
@@ -11,33 +13,54 @@ from diffusers import (
     PNDMScheduler,
     UniPCMultistepScheduler,
 )
+from diffusers.utils import logging as diffusers_logging
+from transformers import logging as transformers_logging
 
 from .pipelines import CustomStableDiffusionImg2ImgPipeline, CustomStableDiffusionPipeline
 
 # improved GPU handling and progress bars; set before importing spaces
-os.environ["ZEROGPU_V2"] = "true"
+os.environ["ZEROGPU_V2"] = "1"
+
+if find_spec("hf_transfer"):
+    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+filterwarnings("ignore", category=FutureWarning, module="diffusers")
+filterwarnings("ignore", category=FutureWarning, module="transformers")
+
+diffusers_logging.set_verbosity_error()
+transformers_logging.set_verbosity_error()
+
+_sd_files = [
+    "feature_extractor/preprocessor_config.json",
+    "safety_checker/config.json",
+    "scheduler/scheduler_config.json",
+    "text_encoder/config.json",
+    "text_encoder/model.fp16.safetensors",
+    "tokenizer/merges.txt",
+    "tokenizer/special_tokens_map.json",
+    "tokenizer/tokenizer_config.json",
+    "tokenizer/vocab.json",
+    "unet/config.json",
+    "unet/diffusion_pytorch_model.fp16.safetensors",
+    "vae/config.json",
+    "vae/diffusion_pytorch_model.fp16.safetensors",
+    "model_index.json",
+]
 
 Config = SimpleNamespace(
     HF_TOKEN=os.environ.get("HF_TOKEN", None),
     CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
     ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
     HF_MODELS={
-        "Lykon/dreamshaper-8": [
-            "feature_extractor/preprocessor_config.json",
-            "safety_checker/config.json",
-            "scheduler/scheduler_config.json",
-            "text_encoder/config.json",
-            "text_encoder/model.fp16.safetensors",
-            "tokenizer/merges.txt",
-            "tokenizer/special_tokens_map.json",
-            "tokenizer/tokenizer_config.json",
-            "tokenizer/vocab.json",
-            "unet/config.json",
-            "unet/diffusion_pytorch_model.fp16.safetensors",
-            "vae/config.json",
-            "vae/diffusion_pytorch_model.fp16.safetensors",
-            "model_index.json",
-        ],
+        # downloaded on startup
+        "Lykon/dreamshaper-8": [*_sd_files],
+        "Comfy-Org/stable-diffusion-v1-5-archive": ["v1-5-pruned-emaonly-fp16.safetensors"],
+        "cyberdelia/CyberRealistic": ["CyberRealistic_V5_FP16.safetensors"],
+        "fluently/Fluently-v4": ["Fluently-v4.safetensors"],
+        "Linaqruf/anything-v3-1": ["anything-v3-2.safetensors"],
+        "prompthero/openjourney-v4": ["openjourney-v4.ckpt"],
+        "SG161222/Realistic_Vision_V5.1_noVAE": ["Realistic_Vision_V5.1_fp16-no-ema.safetensors"],
+        "XpucT/Deliberate": ["Deliberate_v6.safetensors"],
     },
     CIVIT_LORAS={
         # https://civitai.com/models/411088?modelVersionId=486099
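Note on the new block at the top of lib/config.py: the `find_spec` check only opts in to accelerated Hub downloads when the optional `hf_transfer` package is importable. A standalone sketch of that pattern (a minimal illustration, not code from this commit):

    import os
    from importlib.util import find_spec

    # huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER at download time; setting it
    # without the hf_transfer package installed makes downloads error out, hence the guard.
    if find_spec("hf_transfer"):
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

    print("hf_transfer enabled:", os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") == "1")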
lib/utils.py CHANGED
@@ -7,15 +7,22 @@ from typing import Callable, TypeVar
 import anyio
 import httpx
 from anyio import Semaphore
+from diffusers.utils import logging as diffusers_logging
 from huggingface_hub._snapshot_download import snapshot_download
+from huggingface_hub.utils import are_progress_bars_disabled
+from transformers import logging as transformers_logging
 from typing_extensions import ParamSpec
 
+from .logger import Logger
+
 T = TypeVar("T")
 P = ParamSpec("P")
 
 MAX_CONCURRENT_THREADS = 1
 MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
 
+log = Logger("utils")
+
 
 @functools.lru_cache()
 def load_json(path: str) -> dict:
@@ -29,8 +36,21 @@ def read_file(path: str) -> str:
         return file.read()
 
 
+def disable_progress_bars():
+    transformers_logging.disable_progress_bar()
+    diffusers_logging.disable_progress_bar()
+
+
+def enable_progress_bars():
+    # warns if `HF_HUB_DISABLE_PROGRESS_BARS` env var is not None
+    transformers_logging.enable_progress_bar()
+    diffusers_logging.enable_progress_bar()
+
+
 def download_repo_files(repo_id, allow_patterns, token=None):
-    return snapshot_download(
+    was_disabled = are_progress_bars_disabled()
+    enable_progress_bars()
+    snapshot_path = snapshot_download(
         repo_id=repo_id,
         repo_type="model",
         revision="main",
@@ -38,6 +58,9 @@ def download_repo_files(repo_id, allow_patterns, token=None):
         allow_patterns=allow_patterns,
         ignore_patterns=None,
     )
+    if was_disabled:
+        disable_progress_bars()
+    return snapshot_path
 
 
 def download_civit_file(lora_id, version_id, file_path=".", token=None):
@@ -62,9 +85,9 @@ def download_civit_file(lora_id, version_id, file_path=".", token=None):
         with open(file, "wb") as f:
             f.write(response.content)
     except httpx.HTTPStatusError as e:
-        print(f"HTTPError: {e.response.status_code} {e.response.text}")
+        log.error(f"{e.response.status_code} {e.response.text}")
     except httpx.RequestError as e:
-        print(f"RequestError: {e}")
+        log.error(f"RequestError: {e}")
 
 
 # like the original but supports args and kwargs instead of a dict
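Taken together, the new helpers keep library progress bars silenced during inference while the startup downloads still report progress in the terminal. A minimal usage sketch (assumes it runs from the repo root so `lib` is importable, mirroring what app.py now does):

    from lib import Config, disable_progress_bars, download_repo_files

    # Silence diffusers/transformers progress bars for normal operation.
    disable_progress_bars()

    # Downloads still show terminal progress: download_repo_files() re-enables the
    # bars around snapshot_download(), then restores the disabled state before
    # returning the snapshot path.
    for repo_id, allow_patterns in Config.HF_MODELS.items():
        download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)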