John6666 commited on
Commit
d83c4c9
·
verified ·
1 Parent(s): c7c9dd3

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. stkey.py +1 -1
  3. utils.py +109 -31
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐶
4
  colorFrom: yellow
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: yellow
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.1.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
stkey.py CHANGED
@@ -7,7 +7,7 @@ from safetensors.torch import load_file, save_file
7
  import torch
8
 
9
 
10
- SDXL_KEYS_FILE = "sdxl_keys.txt"
11
 
12
 
13
  def list_uniq(l):
 
7
  import torch
8
 
9
 
10
+ SDXL_KEYS_FILE = "keys/sdxl_keys.txt"
11
 
12
 
13
  def list_uniq(l):
utils.py CHANGED
@@ -6,6 +6,8 @@ import shutil
6
  import gc
7
  import re
8
  import urllib.parse
 
 
9
 
10
 
11
  def get_token():
@@ -72,15 +74,24 @@ def list_sub(a, b):
72
 
73
 
74
  def is_repo_name(s):
75
- return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)
 
 
 
 
 
 
 
76
 
77
 
78
  def split_hf_url(url: str):
79
  try:
80
- s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
81
  if len(s) < 4: return "", "", "", ""
82
  repo_id = s[1]
83
- repo_type = "dataset" if s[0] == "datasets" else "model"
 
 
84
  subfolder = urllib.parse.unquote(s[2]) if s[2] else None
85
  filename = urllib.parse.unquote(s[3])
86
  return repo_id, filename, subfolder, repo_type
@@ -92,14 +103,16 @@ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
92
  hf_token = get_token()
93
  repo_id, filename, subfolder, repo_type = split_hf_url(url)
94
  try:
95
- if subfolder is not None: hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
96
- else: hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
 
 
97
  except Exception as e:
98
  print(f"Failed to download: {e}")
 
99
 
100
 
101
  def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
102
- hf_token = get_token()
103
  url = url.strip()
104
  if "drive.google.com" in url:
105
  original_dir = os.getcwd()
@@ -108,14 +121,8 @@ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(trac
108
  os.chdir(original_dir)
109
  elif "huggingface.co" in url:
110
  url = url.replace("?download=true", "")
111
- if "/blob/" in url:
112
- url = url.replace("/blob/", "/resolve/")
113
- #user_header = f'"Authorization: Bearer {hf_token}"'
114
- if True or hf_token:
115
- download_hf_file(directory, url)
116
- #os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
117
- else:
118
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
119
  elif "civitai.com" in url:
120
  if "?" in url:
121
  url = url.split("?")[0]
@@ -128,14 +135,13 @@ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(trac
128
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
129
 
130
 
131
- def get_local_model_list(dir_path):
132
- model_list = []
133
- valid_extensions = ('.safetensors', '.fp16.safetensors', '.sft')
134
  for file in Path(dir_path).glob("**/*.*"):
135
- if file.is_file() and file.suffix in valid_extensions:
136
  file_path = str(file)
137
- model_list.append(file_path)
138
- return model_list
139
 
140
 
141
  def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
@@ -150,13 +156,13 @@ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqd
150
  new_file = f"{temp_dir}/{url.split('/')[-1]}"
151
  else:
152
  print(f"Start downloading: {url}")
153
- before = get_local_model_list(temp_dir)
154
  try:
155
  download_thing(temp_dir, url.strip(), civitai_key)
156
  except Exception:
157
  print(f"Download failed: {url}")
158
  return ""
159
- after = get_local_model_list(temp_dir)
160
  new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
161
  if not new_file:
162
  print(f"Download failed: {url}")
@@ -165,8 +171,7 @@ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqd
165
  return new_file
166
 
167
 
168
- # https://huggingface.co/docs/huggingface_hub/v0.25.1/en/package_reference/file_download#huggingface_hub.snapshot_download
169
- def download_repo(repo_id, dir_path, progress=gr.Progress(track_tqdm=True)):
170
  hf_token = get_token()
171
  try:
172
  snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
@@ -178,20 +183,93 @@ def download_repo(repo_id, dir_path, progress=gr.Progress(track_tqdm=True)):
178
  return False
179
 
180
 
181
- def upload_repo(new_repo_id, dir_path, is_private, progress=gr.Progress(track_tqdm=True)):
182
  hf_token = get_token()
183
  api = HfApi(token=hf_token)
184
  try:
185
  progress(0, desc="Start uploading...")
186
- api.create_repo(repo_id=new_repo_id, token=hf_token, private=is_private, exist_ok=True)
187
  for path in Path(dir_path).glob("*"):
188
  if path.is_dir():
189
- api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
190
  elif path.is_file():
191
- api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
192
  progress(1, desc="Uploaded.")
193
- url = f"https://huggingface.co/{new_repo_id}"
194
  except Exception as e:
195
- print(f"Error: Failed to upload to {new_repo_id}. {e}")
196
  return ""
197
- return url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import gc
7
  import re
8
  import urllib.parse
9
+ import subprocess
10
+ import time
11
 
12
 
13
  def get_token():
 
74
 
75
 
76
  def is_repo_name(s):
77
+ return re.fullmatch(r'^[\w_\-\.]+/[\w_\-\.]+$', s)
78
+
79
+
80
+ def get_hf_url(repo_id: str, repo_type: str="model"):
81
+ if repo_type == "dataset": url = f"https://huggingface.co/datasets/{repo_id}"
82
+ elif repo_type == "space": url = f"https://huggingface.co/spaces/{repo_id}"
83
+ else: url = f"https://huggingface.co/{repo_id}"
84
+ return url
85
 
86
 
87
  def split_hf_url(url: str):
88
  try:
89
+ s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets|spaces)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
90
  if len(s) < 4: return "", "", "", ""
91
  repo_id = s[1]
92
+ if s[0] == "datasets": repo_type = "dataset"
93
+ elif s[0] == "spaces": repo_type = "space"
94
+ else: repo_type = "model"
95
  subfolder = urllib.parse.unquote(s[2]) if s[2] else None
96
  filename = urllib.parse.unquote(s[3])
97
  return repo_id, filename, subfolder, repo_type
 
103
  hf_token = get_token()
104
  repo_id, filename, subfolder, repo_type = split_hf_url(url)
105
  try:
106
+ print(f"Downloading {url} to {directory}")
107
+ if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
108
+ else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
109
+ return path
110
  except Exception as e:
111
  print(f"Failed to download: {e}")
112
+ return None
113
 
114
 
115
  def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
 
116
  url = url.strip()
117
  if "drive.google.com" in url:
118
  original_dir = os.getcwd()
 
121
  os.chdir(original_dir)
122
  elif "huggingface.co" in url:
123
  url = url.replace("?download=true", "")
124
+ if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
125
+ download_hf_file(directory, url)
 
 
 
 
 
 
126
  elif "civitai.com" in url:
127
  if "?" in url:
128
  url = url.split("?")[0]
 
135
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
136
 
137
 
138
+ def get_local_file_list(dir_path):
139
+ file_list = []
 
140
  for file in Path(dir_path).glob("**/*.*"):
141
+ if file.is_file():
142
  file_path = str(file)
143
+ file_list.append(file_path)
144
+ return file_list
145
 
146
 
147
  def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
 
156
  new_file = f"{temp_dir}/{url.split('/')[-1]}"
157
  else:
158
  print(f"Start downloading: {url}")
159
+ before = get_local_file_list(temp_dir)
160
  try:
161
  download_thing(temp_dir, url.strip(), civitai_key)
162
  except Exception:
163
  print(f"Download failed: {url}")
164
  return ""
165
+ after = get_local_file_list(temp_dir)
166
  new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
167
  if not new_file:
168
  print(f"Download failed: {url}")
 
171
  return new_file
172
 
173
 
174
+ def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
 
175
  hf_token = get_token()
176
  try:
177
  snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
 
183
  return False
184
 
185
 
186
+ def upload_repo(repo_id: str, dir_path: str, is_private: bool, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
187
  hf_token = get_token()
188
  api = HfApi(token=hf_token)
189
  try:
190
  progress(0, desc="Start uploading...")
191
+ api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
192
  for path in Path(dir_path).glob("*"):
193
  if path.is_dir():
194
+ api.upload_folder(repo_id=repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
195
  elif path.is_file():
196
+ api.upload_file(repo_id=repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
197
  progress(1, desc="Uploaded.")
198
+ return get_hf_url(repo_id, "model")
199
  except Exception as e:
200
+ print(f"Error: Failed to upload to {repo_id}. {e}")
201
  return ""
202
+
203
+
204
+ HF_SUBFOLDER_NAME = ["None", "user_repo"]
205
+
206
+
207
+ def duplicate_hf_repo(src_repo: str, dst_repo: str, src_repo_type: str, dst_repo_type: str,
208
+ is_private: bool, subfolder_type: str=HF_SUBFOLDER_NAME[1], progress=gr.Progress(track_tqdm=True)):
209
+ hf_token = get_token()
210
+ api = HfApi(token=hf_token)
211
+ try:
212
+ if subfolder_type == "user_repo": subfolder = src_repo.replace("/", "_")
213
+ else: subfolder = ""
214
+ progress(0, desc="Start duplicating...")
215
+ api.create_repo(repo_id=dst_repo, repo_type=dst_repo_type, private=is_private, exist_ok=True, token=hf_token)
216
+ for path in api.list_repo_files(repo_id=src_repo, repo_type=src_repo_type, token=hf_token):
217
+ file = hf_hub_download(repo_id=src_repo, filename=path, repo_type=src_repo_type, token=hf_token)
218
+ if not Path(file).exists(): continue
219
+ if Path(file).is_dir(): # unused for now
220
+ api.upload_folder(repo_id=dst_repo, folder_path=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
221
+ repo_type=dst_repo_type, token=hf_token)
222
+ elif Path(file).is_file():
223
+ api.upload_file(repo_id=dst_repo, path_or_fileobj=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
224
+ repo_type=dst_repo_type, token=hf_token)
225
+ if Path(file).exists(): Path(file).unlink()
226
+ progress(1, desc="Duplicated.")
227
+ return f"{get_hf_url(dst_repo, dst_repo_type)}/tree/main/{subfolder}" if subfolder else get_hf_url(dst_repo, dst_repo_type)
228
+ except Exception as e:
229
+ print(f"Error: Failed to duplicate repo {src_repo} to {dst_repo}. {e}")
230
+ return ""
231
+
232
+
233
+ BASE_DIR = str(Path(__file__).resolve().parent.resolve())
234
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
235
+
236
+
237
+ def get_file(url: str, path: str): # requires aria2, gdown
238
+ print(f"Downloading {url} to {path}...")
239
+ get_download_file(path, url, CIVITAI_API_KEY)
240
+
241
+
242
+ def git_clone(url: str, path: str, pip: bool=False, addcmd: str=""): # requires git
243
+ os.makedirs(str(Path(BASE_DIR, path)), exist_ok=True)
244
+ os.chdir(Path(BASE_DIR, path))
245
+ print(f"Cloning {url} to {path}...")
246
+ cmd = f'git clone {url}'
247
+ print(f'Running {cmd} at {Path.cwd()}')
248
+ i = subprocess.run(cmd, shell=True).returncode
249
+ if i != 0: print(f'Error occured at running {cmd}')
250
+ p = url.split("/")[-1]
251
+ if not Path(p).exists: return
252
+ if pip:
253
+ os.chdir(Path(BASE_DIR, path, p))
254
+ cmd = f'pip install -r requirements.txt'
255
+ print(f'Running {cmd} at {Path.cwd()}')
256
+ i = subprocess.run(cmd, shell=True).returncode
257
+ if i != 0: print(f'Error occured at running {cmd}')
258
+ if addcmd:
259
+ os.chdir(Path(BASE_DIR, path, p))
260
+ cmd = addcmd
261
+ print(f'Running {cmd} at {Path.cwd()}')
262
+ i = subprocess.run(cmd, shell=True).returncode
263
+ if i != 0: print(f'Error occured at running {cmd}')
264
+
265
+
266
+ def run(cmd: str, timeout: float=0):
267
+ print(f'Running {cmd} at {Path.cwd()}')
268
+ if timeout == 0:
269
+ i = subprocess.run(cmd, shell=True).returncode
270
+ if i != 0: print(f'Error occured at running {cmd}')
271
+ else:
272
+ p = subprocess.Popen(cmd, shell=True)
273
+ time.sleep(timeout)
274
+ p.terminate()
275
+ print(f'Terminated in {timeout} seconds')