hysts HF staff commited on
Commit
64bc782
·
1 Parent(s): 92f1008
Files changed (1) hide show
  1. model.py +43 -5
model.py CHANGED
@@ -40,11 +40,24 @@ ORIGINAL_MODEL_NAMES = {
40
  }
41
  ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  class Model:
45
  def __init__(self,
46
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
47
- model_dir: str = 'models'):
 
48
  self.device = torch.device(
49
  'cuda:0' if torch.cuda.is_available() else 'cpu')
50
  self.model = create_model(model_config_path).to(self.device)
@@ -54,16 +67,41 @@ class Model:
54
  self.model_dir = pathlib.Path(model_dir)
55
  self.model_dir.mkdir(exist_ok=True, parents=True)
56
 
57
- self.model_names = ORIGINAL_MODEL_NAMES
58
- self.weight_root = ORIGINAL_WEIGHT_ROOT
 
 
 
 
 
 
 
 
59
  self.download_models()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def load_weight(self, task_name: str) -> None:
62
  if task_name == self.task_name:
63
  return
64
  weight_path = self.get_weight_path(task_name)
65
- self.model.load_state_dict(
66
- load_state_dict(weight_path, location=self.device))
 
 
 
 
67
  self.task_name = task_name
68
 
69
  def get_weight_path(self, task_name: str) -> str:
 
40
  }
41
  ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
42
 
43
+ LIGHTWEIGHT_MODEL_NAMES = {
44
+ 'canny': 'control_canny-fp16.safetensors',
45
+ 'hough': 'control_mlsd-fp16.safetensors',
46
+ 'hed': 'control_hed-fp16.safetensors',
47
+ 'scribble': 'control_scribble-fp16.safetensors',
48
+ 'pose': 'control_openpose-fp16.safetensors',
49
+ 'seg': 'control_seg-fp16.safetensors',
50
+ 'depth': 'control_depth-fp16.safetensors',
51
+ 'normal': 'control_normal-fp16.safetensors',
52
+ }
53
+ LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
54
+
55
 
56
  class Model:
57
  def __init__(self,
58
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
59
+ model_dir: str = 'models',
60
+ use_lightweight: bool = True):
61
  self.device = torch.device(
62
  'cuda:0' if torch.cuda.is_available() else 'cpu')
63
  self.model = create_model(model_config_path).to(self.device)
 
67
  self.model_dir = pathlib.Path(model_dir)
68
  self.model_dir.mkdir(exist_ok=True, parents=True)
69
 
70
+ self.use_lightweight = use_lightweight
71
+ if use_lightweight:
72
+ self.model_names = LIGHTWEIGHT_MODEL_NAMES
73
+ self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
74
+ base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
75
+ self.load_base_model(base_model_url)
76
+ else:
77
+ self.model_names = ORIGINAL_MODEL_NAMES
78
+ self.weight_root = ORIGINAL_WEIGHT_ROOT
79
+
80
  self.download_models()
81
 
82
+ def download_base_model(self, model_url: str) -> pathlib.Path:
83
+ model_name = model_url.split('/')[-1]
84
+ out_path = self.model_dir / model_name
85
+ if not out_path.exists():
86
+ subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
87
+ return out_path
88
+
89
+ def load_base_model(self, model_url: str) -> None:
90
+ model_path = self.download_base_model(model_url)
91
+ self.model.load_state_dict(load_state_dict(model_path,
92
+ location=self.device.type),
93
+ strict=False)
94
+
95
  def load_weight(self, task_name: str) -> None:
96
  if task_name == self.task_name:
97
  return
98
  weight_path = self.get_weight_path(task_name)
99
+ if not self.use_lightweight:
100
+ self.model.load_state_dict(
101
+ load_state_dict(weight_path, location=self.device))
102
+ else:
103
+ self.model.control_model.load_state_dict(
104
+ load_state_dict(weight_path, location=self.device.type))
105
  self.task_name = task_name
106
 
107
  def get_weight_path(self, task_name: str) -> str: