Spaces:
Running
Running
hysts
commited on
Commit
•
6be65e1
1
Parent(s):
0e88f89
Clean up
Browse files
model.py
CHANGED
@@ -71,22 +71,21 @@ class Model:
|
|
71 |
self.model_names = LIGHTWEIGHT_MODEL_NAMES
|
72 |
self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
|
73 |
base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
74 |
-
self.
|
75 |
-
base_model_path = self.model_dir / base_model_url.split('/')[-1]
|
76 |
-
self.load_base_model(base_model_path)
|
77 |
else:
|
78 |
self.model_names = ORIGINAL_MODEL_NAMES
|
79 |
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
80 |
self.download_models()
|
81 |
|
82 |
-
def download_base_model(self,
|
83 |
-
model_name =
|
84 |
out_path = self.model_dir / model_name
|
85 |
-
if out_path.exists():
|
86 |
-
|
87 |
-
|
88 |
|
89 |
-
def load_base_model(self,
|
|
|
90 |
self.model.load_state_dict(load_state_dict(model_path,
|
91 |
location=self.device.type),
|
92 |
strict=False)
|
|
|
71 |
self.model_names = LIGHTWEIGHT_MODEL_NAMES
|
72 |
self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
|
73 |
base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
74 |
+
self.load_base_model(base_model_url)
|
|
|
|
|
75 |
else:
|
76 |
self.model_names = ORIGINAL_MODEL_NAMES
|
77 |
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
78 |
self.download_models()
|
79 |
|
80 |
+
def download_base_model(self, model_url: str) -> pathlib.Path:
|
81 |
+
model_name = model_url.split('/')[-1]
|
82 |
out_path = self.model_dir / model_name
|
83 |
+
if not out_path.exists():
|
84 |
+
subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
|
85 |
+
return out_path
|
86 |
|
87 |
+
def load_base_model(self, model_url: str) -> None:
|
88 |
+
model_path = self.download_base_model(model_url)
|
89 |
self.model.load_state_dict(load_state_dict(model_path,
|
90 |
location=self.device.type),
|
91 |
strict=False)
|