John6666 commited on
Commit
d2ce16f
1 Parent(s): 7f70a45

Upload convert_url_to_diffusers_sdxl_gr.py

Browse files
Files changed (1) hide show
  1. convert_url_to_diffusers_sdxl_gr.py +65 -169
convert_url_to_diffusers_sdxl_gr.py CHANGED
@@ -1,3 +1,13 @@
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  from pathlib import Path
3
  import os
@@ -5,108 +15,30 @@ import torch
5
  from diffusers import StableDiffusionXLPipeline, AutoencoderKL
6
  from transformers import CLIPTokenizer, CLIPTextModel
7
  import gradio as gr
8
- from huggingface_hub import hf_hub_download, HfApi
9
- import urllib.parse
10
- import re
11
  import shutil
12
  import gc
13
  # also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
 
14
 
15
 
16
- def list_sub(a, b):
17
- return [e for e in a if e not in b]
18
-
19
-
20
- def is_repo_name(s):
21
- return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
22
-
23
-
24
- def split_hf_url(url: str):
25
- try:
26
- s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
27
- if len(s) < 4: return "", "", "", ""
28
- repo_id = s[1]
29
- repo_type = "dataset" if s[0] == "datasets" else "model"
30
- subfolder = urllib.parse.unquote(s[2]) if s[2] else None
31
- filename = urllib.parse.unquote(s[3])
32
- return repo_id, filename, subfolder, repo_type
33
- except Exception as e:
34
- print(e)
35
-
36
-
37
- def download_hf_file(directory, url, hf_token="", progress=gr.Progress(track_tqdm=True)):
38
- repo_id, filename, subfolder, repo_type = split_hf_url(url)
39
- try:
40
- if subfolder is not None: hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
41
- else: hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
42
- except Exception as e:
43
- print(f"Failed to download: {e}")
44
-
45
-
46
- def download_thing(directory, url, civitai_api_key="", hf_token="", progress=gr.Progress(track_tqdm=True)):
47
- url = url.strip()
48
- if "drive.google.com" in url:
49
- original_dir = os.getcwd()
50
- os.chdir(directory)
51
- os.system(f"gdown --fuzzy {url}")
52
- os.chdir(original_dir)
53
- elif "huggingface.co" in url:
54
- url = url.replace("?download=true", "")
55
- if "/blob/" in url:
56
- url = url.replace("/blob/", "/resolve/")
57
- user_header = f'"Authorization: Bearer {hf_token}"'
58
- if hf_token:
59
- download_hf_file(directory, url, hf_token)
60
- #os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
61
- else:
62
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
63
- elif "civitai.com" in url:
64
- if "?" in url:
65
- url = url.split("?")[0]
66
- if civitai_api_key:
67
- url = url + f"?token={civitai_api_key}"
68
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
69
- else:
70
- print("You need an API key to download Civitai models.")
71
- else:
72
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
73
-
74
-
75
- def get_local_model_list(dir_path):
76
- model_list = []
77
- valid_extensions = ('.safetensors')
78
- for file in Path(dir_path).glob("**/*.*"):
79
- if file.is_file() and file.suffix in valid_extensions:
80
- file_path = str(file)
81
- model_list.append(file_path)
82
- return model_list
83
-
84
-
85
- def get_download_file(temp_dir, url, civitai_key, hf_token, progress=gr.Progress(track_tqdm=True)):
86
- if not "http" in url and is_repo_name(url) and not Path(url).exists():
87
- print(f"Use HF Repo: {url}")
88
- new_file = url
89
- elif not "http" in url and Path(url).exists():
90
- print(f"Use local file: {url}")
91
- new_file = url
92
- elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
93
- print(f"File to download alreday exists: {url}")
94
- new_file = f"{temp_dir}/{url.split('/')[-1]}"
95
- else:
96
- print(f"Start downloading: {url}")
97
- before = get_local_model_list(temp_dir)
98
- try:
99
- download_thing(temp_dir, url.strip(), civitai_key, hf_token)
100
- except Exception:
101
- print(f"Download failed: {url}")
102
- return ""
103
- after = get_local_model_list(temp_dir)
104
- new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
105
- if not new_file:
106
- print(f"Download failed: {url}")
107
- return ""
108
- print(f"Download completed: {url}")
109
- return new_file
110
 
111
 
112
  from diffusers import (
@@ -224,19 +156,19 @@ tags:
224
  f.write(md)
225
 
226
 
227
- def fuse_loras(pipe, lora_dict={}, temp_dir=".", civitai_key="", hf_token=""):
228
  if not lora_dict or not isinstance(lora_dict, dict): return pipe
229
  a_list = []
230
  w_list = []
231
  for k, v in lora_dict.items():
232
  if not k: continue
233
- new_lora_file = get_download_file(temp_dir, k, civitai_key, hf_token)
234
  if not new_lora_file or not Path(new_lora_file).exists():
235
  print(f"LoRA not found: {k}")
236
  continue
237
  w_name = Path(new_lora_file).name
238
  a_name = Path(new_lora_file).stem
239
- pipe.load_lora_weights(new_lora_file, weight_name = w_name, adapter_name = a_name)
240
  a_list.append(a_name)
241
  w_list.append(v)
242
  if not a_list: return pipe
@@ -246,121 +178,84 @@ def fuse_loras(pipe, lora_dict={}, temp_dir=".", civitai_key="", hf_token=""):
246
  return pipe
247
 
248
 
249
- def convert_url_to_diffusers_sdxl(url, civitai_key="", hf_token="", is_upload_sf=False, dtype="fp16", vae="",
250
- scheduler="Euler a", lora_dict={}, is_local=True, clip="", progress=gr.Progress(track_tqdm=True)):
251
  progress(0, desc="Start converting...")
252
- temp_dir = "."
253
- new_file = get_download_file(temp_dir, url, civitai_key, hf_token)
254
  if not new_file:
255
  print(f"Not found: {url}")
256
  return ""
257
- new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
258
 
259
- type_kwargs = {}
260
  kwargs = {}
261
- if dtype == "fp16": type_kwargs["torch_dtype"] = torch.float16
262
- elif dtype == "fp32": type_kwargs["torch_dtype"] = torch.float32
263
- elif dtype == "bf16": type_kwargs["torch_dtype"] = torch.bfloat16
264
- elif dtype == "fp8": type_kwargs["torch_dtype"] = torch.float8_e4m3fn
265
 
266
  new_vae_file = ""
267
  if vae:
268
  if is_repo_name(vae): my_vae = AutoencoderKL.from_pretrained(vae, **type_kwargs)
269
  else:
270
- new_vae_file = get_download_file(temp_dir, vae, civitai_key, hf_token)
271
- if new_vae_file: my_vae = AutoencoderKL.from_single_file(new_vae_file, **type_kwargs)
272
- kwargs["vae"] = my_vae
273
 
274
  if clip:
275
  my_tokenizer = CLIPTokenizer.from_pretrained(clip)
 
276
  my_text_encoder = CLIPTextModel.from_pretrained(clip, **type_kwargs)
277
- kwargs["tokenizer"] = my_tokenizer
278
- kwargs["text_encoder"] = my_text_encoder
279
 
280
  pipe = None
281
  if is_repo_name(url): pipe = StableDiffusionXLPipeline.from_pretrained(new_file, use_safetensors=True, **kwargs, **type_kwargs)
282
  else: pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **type_kwargs)
283
 
284
- pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, hf_token)
285
 
286
  sconf = get_scheduler_config(scheduler)
287
  pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
288
 
289
- pipe.save_pretrained(new_repo_name, safe_serialization=True, use_safetensors=True)
290
 
291
- if Path(new_repo_name).exists(): save_readme_md(new_repo_name, url)
292
 
293
  if not is_local:
294
- if not is_repo_name(new_file) and is_upload_sf: shutil.move(str(Path(new_file).resolve()), str(Path(new_repo_name, Path(new_file).name).resolve()))
295
  else: os.remove(new_file)
296
  del pipe
297
  torch.cuda.empty_cache()
298
  gc.collect()
299
 
300
  progress(1, desc="Converted.")
301
- return new_repo_name
302
-
303
-
304
- def is_repo_exists(repo_id, hf_token):
305
- api = HfApi(token=hf_token)
306
- try:
307
- if api.repo_exists(repo_id=repo_id): return True
308
- else: return False
309
- except Exception as e:
310
- print(f"Error: Failed to connect {repo_id}. {e}")
311
- return True # for safe
312
-
313
-
314
- def create_diffusers_repo(new_repo_id, diffusers_folder, is_private, hf_token, progress=gr.Progress(track_tqdm=True)):
315
- api = HfApi(token=hf_token)
316
- try:
317
- progress(0, desc="Start uploading...")
318
- api.create_repo(repo_id=new_repo_id, token=hf_token, private=is_private, exist_ok=True)
319
- for path in Path(diffusers_folder).glob("*"):
320
- if path.is_dir():
321
- api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
322
- elif path.is_file():
323
- api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
324
- progress(1, desc="Uploaded.")
325
- url = f"https://huggingface.co/{new_repo_id}"
326
- except Exception as e:
327
- print(f"Error: Failed to upload to {new_repo_id}. {e}")
328
- return ""
329
- return url
330
 
331
 
332
  def convert_url_to_diffusers_repo(dl_url, hf_user, hf_repo, hf_token, civitai_key="", is_private=True, is_overwrite=False, is_upload_sf=False,
333
- repo_urls=[], dtype="fp16", vae=None, clip="", scheduler="Euler a",
334
  lora1=None, lora1s=1.0, lora2=None, lora2s=1.0, lora3=None, lora3s=1.0,
335
  lora4=None, lora4s=1.0, lora5=None, lora5s=1.0, progress=gr.Progress(track_tqdm=True)):
 
336
  if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
337
  if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN") # default HF write token
338
  if not hf_user and os.environ.get("HF_USER"): hf_user = os.environ.get("HF_USER") # default username
339
- if not hf_user:
340
- print(f"Invalid user name: {hf_user}")
341
- progress(1, desc=f"Invalid user name: {hf_user}")
342
- return gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
343
  lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
344
- new_path = convert_url_to_diffusers_sdxl(dl_url, civitai_key, hf_token, is_upload_sf, dtype, vae, scheduler, lora_dict, False, clip)
345
  if not new_path: return ""
346
  new_repo_id = f"{hf_user}/{Path(new_path).stem}"
347
  if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
348
- if not is_repo_name(new_repo_id):
349
- print(f"Invalid repo name: {new_repo_id}")
350
- progress(1, desc=f"Invalid repo name: {new_repo_id}")
351
- return gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
352
- if not is_overwrite and is_repo_exists(new_repo_id, hf_token):
353
- print(f"Repo already exists: {new_repo_id}")
354
- progress(1, desc=f"Repo already exists: {new_repo_id}")
355
- return gr.update(value=repo_urls, choices=repo_urls), gr.update(visible=True)
356
- repo_url = create_diffusers_repo(new_repo_id, new_path, is_private, hf_token)
357
  shutil.rmtree(new_path)
358
- if not repo_urls: repo_urls = []
359
- repo_urls.append(repo_url)
360
  md = "### Your new repo:\n"
361
- for u in repo_urls:
362
  md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
363
- return gr.update(value=repo_urls, choices=repo_urls), gr.update(value=md)
364
 
365
 
366
  if __name__ == "__main__":
@@ -386,10 +281,11 @@ if __name__ == "__main__":
386
  args = parser.parse_args()
387
  assert args.url is not None, "Must provide a URL!"
388
 
 
389
  lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
390
-
391
  if args.loras and Path(args.loras).exists():
392
  for p in Path(args.loras).glob('**/*.safetensors'):
393
  lora_dict[str(p)] = 1.0
 
394
 
395
- convert_url_to_diffusers_sdxl(args.url, args.civitai_key, args.dtype, args.vae, args.scheduler, lora_dict, True)
 
1
+ import os
2
+ if os.environ.get("SPACES_ZERO_GPU") is not None:
3
+ import spaces
4
+ else:
5
+ class spaces:
6
+ @staticmethod
7
+ def GPU(func):
8
+ def wrapper(*args, **kwargs):
9
+ return func(*args, **kwargs)
10
+ return wrapper
11
  import argparse
12
  from pathlib import Path
13
  import os
 
15
  from diffusers import StableDiffusionXLPipeline, AutoencoderKL
16
  from transformers import CLIPTokenizer, CLIPTextModel
17
  import gradio as gr
 
 
 
18
  import shutil
19
  import gc
20
  # also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
21
+ from utils import (set_token, is_repo_exists, is_repo_name, get_download_file, upload_repo)
22
 
23
 
24
+ @spaces.GPU
25
+ def fake_gpu():
26
+ pass
27
+
28
+
29
+ TEMP_DIR = "."
30
+
31
+
32
+ DTYPE_DICT = {
33
+ "fp16": torch.float16,
34
+ "bf16": torch.bfloat16,
35
+ "fp32": torch.float32,
36
+ "fp8": torch.float8_e4m3fn
37
+ }
38
+
39
+
40
+ def get_dtype(dtype: str):
41
+ return DTYPE_DICT.get(dtype, torch.float16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  from diffusers import (
 
156
  f.write(md)
157
 
158
 
159
+ def fuse_loras(pipe, lora_dict={}, temp_dir=TEMP_DIR, civitai_key=""):
160
  if not lora_dict or not isinstance(lora_dict, dict): return pipe
161
  a_list = []
162
  w_list = []
163
  for k, v in lora_dict.items():
164
  if not k: continue
165
+ new_lora_file = get_download_file(temp_dir, k, civitai_key)
166
  if not new_lora_file or not Path(new_lora_file).exists():
167
  print(f"LoRA not found: {k}")
168
  continue
169
  w_name = Path(new_lora_file).name
170
  a_name = Path(new_lora_file).stem
171
+ pipe.load_lora_weights(new_lora_file, weight_name=w_name, adapter_name=a_name)
172
  a_list.append(a_name)
173
  w_list.append(v)
174
  if not a_list: return pipe
 
178
  return pipe
179
 
180
 
181
+ def convert_url_to_diffusers_sdxl(url, civitai_key="", is_upload_sf=False, dtype="fp16", vae="", clip="",
182
+ scheduler="Euler a", lora_dict={}, is_local=True, progress=gr.Progress(track_tqdm=True)):
183
  progress(0, desc="Start converting...")
184
+ temp_dir = TEMP_DIR
185
+ new_file = get_download_file(temp_dir, url, civitai_key)
186
  if not new_file:
187
  print(f"Not found: {url}")
188
  return ""
189
+ new_dir = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
190
 
 
191
  kwargs = {}
192
+ type_kwargs = {}
193
+ if dtype != "default": type_kwargs["torch_dtype"] = get_dtype(dtype)
 
 
194
 
195
  new_vae_file = ""
196
  if vae:
197
  if is_repo_name(vae): my_vae = AutoencoderKL.from_pretrained(vae, **type_kwargs)
198
  else:
199
+ new_vae_file = get_download_file(temp_dir, vae, civitai_key)
200
+ my_vae = AutoencoderKL.from_single_file(new_vae_file, **type_kwargs) if new_vae_file else None
201
+ if my_vae: kwargs["vae"] = my_vae
202
 
203
  if clip:
204
  my_tokenizer = CLIPTokenizer.from_pretrained(clip)
205
+ if my_tokenizer: kwargs["tokenizer"] = my_tokenizer
206
  my_text_encoder = CLIPTextModel.from_pretrained(clip, **type_kwargs)
207
+ if my_text_encoder: kwargs["text_encoder"] = my_text_encoder
 
208
 
209
  pipe = None
210
  if is_repo_name(url): pipe = StableDiffusionXLPipeline.from_pretrained(new_file, use_safetensors=True, **kwargs, **type_kwargs)
211
  else: pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **type_kwargs)
212
 
213
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key)
214
 
215
  sconf = get_scheduler_config(scheduler)
216
  pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
217
 
218
+ pipe.save_pretrained(new_dir, safe_serialization=True, use_safetensors=True)
219
 
220
+ if Path(new_dir).exists(): save_readme_md(new_dir, url)
221
 
222
  if not is_local:
223
+ if not is_repo_name(new_file) and is_upload_sf: shutil.move(str(Path(new_file).resolve()), str(Path(new_dir, Path(new_file).name).resolve()))
224
  else: os.remove(new_file)
225
  del pipe
226
  torch.cuda.empty_cache()
227
  gc.collect()
228
 
229
  progress(1, desc="Converted.")
230
+ return new_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
 
233
  def convert_url_to_diffusers_repo(dl_url, hf_user, hf_repo, hf_token, civitai_key="", is_private=True, is_overwrite=False, is_upload_sf=False,
234
+ urls=[], dtype="fp16", vae="", clip="", scheduler="Euler a",
235
  lora1=None, lora1s=1.0, lora2=None, lora2s=1.0, lora3=None, lora3s=1.0,
236
  lora4=None, lora4s=1.0, lora5=None, lora5s=1.0, progress=gr.Progress(track_tqdm=True)):
237
+ is_local = False
238
  if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
239
  if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN") # default HF write token
240
  if not hf_user and os.environ.get("HF_USER"): hf_user = os.environ.get("HF_USER") # default username
241
+ if not hf_user: raise gr.Error(f"Invalid user name: {hf_user}")
242
+ if not hf_repo and os.environ.get("HF_REPO"): hf_repo = os.environ.get("HF_REPO") # default reponame
243
+ set_token(hf_token)
 
244
  lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
245
+ new_path = convert_url_to_diffusers_sdxl(dl_url, civitai_key, is_upload_sf, dtype, vae, clip, scheduler, lora_dict, is_local)
246
  if not new_path: return ""
247
  new_repo_id = f"{hf_user}/{Path(new_path).stem}"
248
  if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
249
+ if not is_repo_name(new_repo_id): raise gr.Error(f"Invalid repo name: {new_repo_id}")
250
+ if not is_overwrite and is_repo_exists(new_repo_id): raise gr.Error(f"Repo already exists: {new_repo_id}")
251
+ repo_url = upload_repo(new_repo_id, new_path, is_private)
 
 
 
 
 
 
252
  shutil.rmtree(new_path)
253
+ if not urls: urls = []
254
+ urls.append(repo_url)
255
  md = "### Your new repo:\n"
256
+ for u in urls:
257
  md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
258
+ return gr.update(value=urls, choices=urls), gr.update(value=md)
259
 
260
 
261
  if __name__ == "__main__":
 
281
  args = parser.parse_args()
282
  assert args.url is not None, "Must provide a URL!"
283
 
284
+ is_local = True
285
  lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
 
286
  if args.loras and Path(args.loras).exists():
287
  for p in Path(args.loras).glob('**/*.safetensors'):
288
  lora_dict[str(p)] = 1.0
289
+ clip = ""
290
 
291
+ convert_url_to_diffusers_sdxl(args.url, args.civitai_key, args.dtype, args.vae, clip, args.scheduler, lora_dict, is_local)