# Source provenance: John6666 / "Upload 8 files" (commit 7847c2d, ~4.1 kB).
# (The lines above were Hugging Face web-UI residue, converted to a comment so the file parses.)
import gradio as gr
from huggingface_hub import HfApi, hf_hub_url
import os
from pathlib import Path
import gc
import re
import json
from utils import get_token, set_token, is_repo_exists, get_user_agent, get_download_file, list_uniq
from stkey import read_safetensors_key, read_safetensors_metadata, validate_keys, write_safetensors_key
# Directory where downloaded files are written before processing (current dir).
TEMP_DIR = "."
# Reference key-list files; KEYS_FILES[0] is the default `rfile` used by
# get_stkey/stkey_gr when validating safetensors keys.
KEYS_FILES = ["sdxl_keys.txt"]
def parse_urls(s):
    """Extract every http(s) URL found in *s* and return them as a list.

    Returns an empty list when *s* contains no URLs or is not a string
    (e.g. None), instead of raising.
    """
    # Raw string so the character class doesn't need double backslashes.
    url_pattern = r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+"
    try:
        # re.findall already returns a list; no extra list() wrapper needed.
        return re.findall(url_pattern, s)
    except Exception:
        # Non-string input (None, bytes mix-ups, ...) is treated as "no URLs".
        return []
def to_urls(l: list[str]):
    """Render the given URL strings as one newline-separated block of text."""
    lines = list(l)
    return "\n".join(lines)
def uniq_urls(s):
    """Parse URLs out of *s*, drop duplicates, and return them one per line."""
    found = parse_urls(s)
    unique = list_uniq(found)
    return to_urls(unique)
def upload_safetensors_to_repo(filename, repo_id, repo_type, is_private, progress=gr.Progress(track_tqdm=True)):
    """Upload a local file to a Hugging Face repo and return its hub URL.

    Creates the repo first when it does not exist yet. The local file is
    always deleted afterwards, whether or not the upload succeeded.

    Returns:
        The `hf_hub_url` of the uploaded file, or None on any failure
        (the error is printed and surfaced as a Gradio warning).
    """
    output_filename = Path(filename).name
    hf_token = get_token()
    api = HfApi(token=hf_token)
    try:
        if not is_repo_exists(repo_id, repo_type):
            api.create_repo(repo_id=repo_id, repo_type=repo_type, token=hf_token, private=is_private)
        # NOTE(review): "(unknown)" looks like a redacted placeholder —
        # probably meant to show the filename; confirm against the original.
        progress(0, desc=f"Start uploading... (unknown) to {repo_id}")
        api.upload_file(path_or_fileobj=filename, path_in_repo=output_filename, repo_type=repo_type, revision="main", token=hf_token, repo_id=repo_id)
        progress(1, desc="Uploaded.")
        url = hf_hub_url(repo_id=repo_id, repo_type=repo_type, filename=output_filename)
    except Exception as e:
        print(f"Error: Failed to upload to {repo_id}. {e}")
        gr.Warning(f"Error: Failed to upload to {repo_id}. {e}")
        return None
    finally:
        # missing_ok: the file may already be gone (or never created) after a
        # failure; a bare unlink() would raise FileNotFoundError here and mask
        # the real outcome.
        Path(filename).unlink(missing_ok=True)
    return url
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download *dl_url* into TEMP_DIR and return the local file path."""
    progress(0, desc=f"Start downloading... {dl_url}")
    target_dir = TEMP_DIR
    local_path = get_download_file(target_dir, dl_url, civitai_key)
    return local_path
def get_stkey(filename: str, is_validate: bool=True, rfile: str=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Read the tensor keys and metadata of a safetensors file, optionally
    validating the keys against the reference list in *rfile*.

    Side effects: writes "<stem>.txt" (and, when validation produced output,
    "<stem>_missing.txt" / "<stem>_added.txt"); the input file is always
    deleted afterwards.

    Returns:
        (paths, metadata, keys, missing, added) — empty containers on failure.
    """
    paths = []
    metadata = {}
    keys = []
    missing = []
    added = []
    stem = str(Path(filename).stem)  # hoisted: basis for every output filename
    try:
        progress(0, desc=f"Loading keys... (unknown)")
        keys = read_safetensors_key(filename)
        if len(keys) == 0:
            raise Exception("No keys found.")
        progress(0.5, desc=f"Checking keys... (unknown)")
        if write_safetensors_key(keys, stem + ".txt", is_validate, rfile):
            paths.append(stem + ".txt")
            paths.append(stem + "_missing.txt")
            paths.append(stem + "_added.txt")
            # presumably the missing/added lists mirror the two files written
            # above — confirm indentation against the original source.
            missing, added = validate_keys(keys, rfile)
        metadata = read_safetensors_metadata(filename)
    except Exception as e:
        print(f"Error: Failed check (unknown). {e}")
        gr.Warning(f"Error: Failed check (unknown). {e}")
    finally:
        # missing_ok: avoid raising a second error if the file is already gone.
        Path(filename).unlink(missing_ok=True)
    return paths, metadata, keys, missing, added
def stkey_gr(dl_url: str, civitai_key: str, hf_token: str, urls: list[str], files: list[str],
    is_validate=True, rfile=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: download each URL in *dl_url*, extract and validate its
    safetensors keys, and return updated UI component states."""
    # Fall back to environment-supplied credentials when the fields are empty.
    set_token(hf_token if hf_token else os.environ.get("HF_TOKEN"))
    if not civitai_key:
        civitai_key = os.environ.get("CIVITAI_API_KEY")
    urls = urls if urls else []
    files = files if files else []
    metadata = {}
    keys = []
    missing = []
    added = []
    for link in parse_urls(dl_url):
        local = download_file(link, civitai_key)
        target = Path(local)
        if not (target.exists() and target.is_file()):
            continue
        # Only the results of the last processed file are returned.
        paths, metadata, keys, missing, added = get_stkey(local, is_validate, rfile)
        if paths:
            files.extend(paths)
    progress(1, desc="Processing...")
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=files), gr.update(visible=True), metadata, keys, missing, added