Added custom models option
Browse files
model.py
CHANGED
@@ -12,6 +12,7 @@ import cv2
|
|
12 |
import einops
|
13 |
import numpy as np
|
14 |
import torch
|
|
|
15 |
from pytorch_lightning import seed_everything
|
16 |
|
17 |
sys.path.append('ControlNet')
|
@@ -28,19 +29,8 @@ from cldm.model import create_model, load_state_dict
|
|
28 |
from ldm.models.diffusion.ddim import DDIMSampler
|
29 |
from share import *
|
30 |
|
31 |
-
ORIGINAL_MODEL_NAMES = {
|
32 |
-
'canny': 'control_sd15_canny.pth',
|
33 |
-
'hough': 'control_sd15_mlsd.pth',
|
34 |
-
'hed': 'control_sd15_hed.pth',
|
35 |
-
'scribble': 'control_sd15_scribble.pth',
|
36 |
-
'pose': 'control_sd15_openpose.pth',
|
37 |
-
'seg': 'control_sd15_seg.pth',
|
38 |
-
'depth': 'control_sd15_depth.pth',
|
39 |
-
'normal': 'control_sd15_normal.pth',
|
40 |
-
}
|
41 |
-
ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
|
42 |
|
43 |
-
|
44 |
'canny': 'control_canny-fp16.safetensors',
|
45 |
'hough': 'control_mlsd-fp16.safetensors',
|
46 |
'hed': 'control_hed-fp16.safetensors',
|
@@ -50,36 +40,44 @@ LIGHTWEIGHT_MODEL_NAMES = {
|
|
50 |
'depth': 'control_depth-fp16.safetensors',
|
51 |
'normal': 'control_normal-fp16.safetensors',
|
52 |
}
|
53 |
-
LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
|
54 |
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
class Model:
|
57 |
def __init__(self,
|
58 |
model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
|
59 |
-
model_dir: str = 'models'
|
60 |
-
use_lightweight: bool = True):
|
61 |
self.device = torch.device(
|
62 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
63 |
self.model = create_model(model_config_path).to(self.device)
|
64 |
self.ddim_sampler = DDIMSampler(self.model)
|
65 |
self.task_name = ''
|
66 |
-
|
|
|
|
|
67 |
self.model_dir = pathlib.Path(model_dir)
|
68 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
69 |
|
70 |
-
self.use_lightweight = use_lightweight
|
71 |
-
if use_lightweight:
|
72 |
-
self.model_names = LIGHTWEIGHT_MODEL_NAMES
|
73 |
-
self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
|
74 |
-
base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
75 |
-
self.load_base_model(base_model_url)
|
76 |
-
else:
|
77 |
-
self.model_names = ORIGINAL_MODEL_NAMES
|
78 |
-
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
79 |
-
|
80 |
self.download_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
|
|
82 |
def download_base_model(self, model_url: str) -> pathlib.Path:
|
|
|
83 |
model_name = model_url.split('/')[-1]
|
84 |
out_path = self.model_dir / model_name
|
85 |
if not out_path.exists():
|
@@ -96,27 +94,23 @@ class Model:
|
|
96 |
if task_name == self.task_name:
|
97 |
return
|
98 |
weight_path = self.get_weight_path(task_name)
|
99 |
-
|
100 |
-
self.
|
101 |
-
load_state_dict(weight_path, location=self.device))
|
102 |
-
else:
|
103 |
-
self.model.control_model.load_state_dict(
|
104 |
-
load_state_dict(weight_path, location=self.device.type))
|
105 |
self.task_name = task_name
|
106 |
|
107 |
def get_weight_path(self, task_name: str) -> str:
|
108 |
if 'scribble' in task_name:
|
109 |
task_name = 'scribble'
|
110 |
-
return f'{self.model_dir}/{
|
111 |
|
112 |
def download_models(self) -> None:
|
113 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
114 |
-
for name in
|
115 |
out_path = self.model_dir / name
|
116 |
if out_path.exists():
|
117 |
continue
|
118 |
-
|
119 |
-
|
120 |
|
121 |
@torch.inference_mode()
|
122 |
def process_canny(self, input_image, prompt, a_prompt, n_prompt,
|
@@ -763,4 +757,4 @@ class Model:
|
|
763 |
127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
764 |
|
765 |
results = [x_samples[i] for i in range(num_samples)]
|
766 |
-
return [detected_map] + results
|
|
|
12 |
import einops
|
13 |
import numpy as np
|
14 |
import torch
|
15 |
+
from huggingface_hub import hf_hub_url
|
16 |
from pytorch_lightning import seed_everything
|
17 |
|
18 |
sys.path.append('ControlNet')
|
|
|
29 |
from ldm.models.diffusion.ddim import DDIMSampler
|
30 |
from share import *
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
MODEL_NAMES = {
|
34 |
'canny': 'control_canny-fp16.safetensors',
|
35 |
'hough': 'control_mlsd-fp16.safetensors',
|
36 |
'hed': 'control_hed-fp16.safetensors',
|
|
|
40 |
'depth': 'control_depth-fp16.safetensors',
|
41 |
'normal': 'control_normal-fp16.safetensors',
|
42 |
}
|
|
|
43 |
|
44 |
+
MODEL_REPO = 'webui/ControlNet-modules-safetensors'
|
45 |
+
|
46 |
+
DEFAULT_BASE_MODEL_REPO = 'runwayml/stable-diffusion-v1-5'
|
47 |
+
DEFAULT_BASE_MODEL_FILENAME = 'v1-5-pruned-emaonly.safetensors'
|
48 |
+
DEFAULT_BASE_MODEL_URL = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
49 |
|
50 |
class Model:
|
51 |
def __init__(self,
|
52 |
model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
|
53 |
+
model_dir: str = 'models'):
|
|
|
54 |
self.device = torch.device(
|
55 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
56 |
self.model = create_model(model_config_path).to(self.device)
|
57 |
self.ddim_sampler = DDIMSampler(self.model)
|
58 |
self.task_name = ''
|
59 |
+
|
60 |
+
self.base_model_url = ''
|
61 |
+
|
62 |
self.model_dir = pathlib.Path(model_dir)
|
63 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
self.download_models()
|
66 |
+
self.set_base_model(DEFAULT_BASE_MODEL_REPO,
|
67 |
+
DEFAULT_BASE_MODEL_FILENAME)
|
68 |
+
|
69 |
+
def set_base_model(self, model_id: str, filename: str) -> str:
|
70 |
+
if not model_id or not filename:
|
71 |
+
return self.base_model_url
|
72 |
+
base_model_url = hf_hub_url(model_id, filename)
|
73 |
+
if base_model_url != self.base_model_url:
|
74 |
+
self.load_base_model(base_model_url)
|
75 |
+
self.base_model_url = base_model_url
|
76 |
+
return self.base_model_url
|
77 |
|
78 |
+
|
79 |
def download_base_model(self, model_url: str) -> pathlib.Path:
|
80 |
+
self.model_dir.mkdir(exist_ok=True, parents=True)
|
81 |
model_name = model_url.split('/')[-1]
|
82 |
out_path = self.model_dir / model_name
|
83 |
if not out_path.exists():
|
|
|
94 |
if task_name == self.task_name:
|
95 |
return
|
96 |
weight_path = self.get_weight_path(task_name)
|
97 |
+
self.model.control_model.load_state_dict(
|
98 |
+
load_state_dict(weight_path, location=self.device.type))
|
|
|
|
|
|
|
|
|
99 |
self.task_name = task_name
|
100 |
|
101 |
def get_weight_path(self, task_name: str) -> str:
|
102 |
if 'scribble' in task_name:
|
103 |
task_name = 'scribble'
|
104 |
+
return f'{self.model_dir}/{MODEL_NAMES[task_name]}'
|
105 |
|
106 |
def download_models(self) -> None:
|
107 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
108 |
+
for name in MODEL_NAMES.values():
|
109 |
out_path = self.model_dir / name
|
110 |
if out_path.exists():
|
111 |
continue
|
112 |
+
model_url = hf_hub_url(MODEL_REPO, name)
|
113 |
+
subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
|
114 |
|
115 |
@torch.inference_mode()
|
116 |
def process_canny(self, input_image, prompt, a_prompt, n_prompt,
|
|
|
757 |
127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
758 |
|
759 |
results = [x_samples[i] for i in range(num_samples)]
|
760 |
+
return [detected_map] + results
|