John6666's picture
Upload 15 files
c7c9dd3 verified
raw
history blame
3.58 kB
import gradio as gr
import os
from pathlib import Path
import gc
import re
import shutil
from utils import set_token, get_download_file, list_uniq
from stkey import read_safetensors_key, read_safetensors_metadata, validate_keys, write_safetensors_key
# Working directory used for downloads and for the generated key-list files.
TEMP_DIR = "."
# Directory holding the reference key lists (*.txt) offered in the UI.
KEYS_DIR = "keys"
DEFAULT_KEYS_FILE = f"{KEYS_DIR}/sdxl_keys.txt"
# Initial value only; update_keys_files() rebinds this after scanning KEYS_DIR.
KEYS_FILES = [DEFAULT_KEYS_FILE]
def update_keys_files():
    """Rescan KEYS_DIR for ``*.txt`` key lists and rebind the global KEYS_FILES.

    The result is sorted so its order, and therefore ``KEYS_FILES[0]`` (used as a
    default argument elsewhere in this module), does not depend on filesystem
    iteration order. DEFAULT_KEYS_FILE is placed first when it is present. If no
    key list is found (including when KEYS_DIR does not exist), KEYS_FILES falls
    back to ``[DEFAULT_KEYS_FILE]`` so that ``KEYS_FILES[0]`` cannot raise
    IndexError.
    """
    global KEYS_FILES
    default = Path(DEFAULT_KEYS_FILE)
    found = (str(file) for file in Path(KEYS_DIR).glob("*.txt"))
    # Sort key: the default file first (False < True), then alphabetical.
    files = sorted(found, key=lambda f: (Path(f) != default, f))
    KEYS_FILES = files if files else [DEFAULT_KEYS_FILE]
# Populate KEYS_FILES from KEYS_DIR at import time. The default `rfile` of
# get_stkey and stkey_gr below is KEYS_FILES[0], evaluated when those functions
# are defined, so this call must run before them.
update_keys_files()
def upload_keys_file(path: str):
    """Copy an uploaded key list into KEYS_DIR and refresh the available choices.

    The file is stored as ``KEYS_DIR/<stem>.txt``. An existing file with the same
    name is left untouched (never overwritten).

    Args:
        path: Local path of the uploaded file.

    Returns:
        A ``gr.update`` whose choices are the refreshed KEYS_FILES list.
    """
    keys_dir = Path(KEYS_DIR)
    # shutil.copy does not create missing parent directories.
    keys_dir.mkdir(parents=True, exist_ok=True)
    newpath = keys_dir / (Path(path).stem + ".txt")
    if not newpath.exists():
        shutil.copy(path, newpath)
    update_keys_files()
    # KEYS_FILES is only read here, so it resolves to the freshly rebound global.
    return gr.update(choices=KEYS_FILES)
# Compiled once at import; same character class as before, now as a raw string.
_URL_RE = re.compile(r"https?://[\w/:%#\$&\?\(\)~\.=\+\-]+")


def parse_urls(s):
    """Return every http(s) URL found in ``s``, in order of appearance.

    Args:
        s: Text to scan.

    Returns:
        A list of matched URL strings. Non-string input (e.g. ``None``) yields ``[]``.
    """
    try:
        return _URL_RE.findall(s)
    except TypeError:
        # re raises TypeError for None / non-str input; treat it as "no URLs".
        return []
def to_urls(l: list[str]):
    """Join the given URLs into a single newline-separated string."""
    newline = "\n"
    return newline.join(l)
def uniq_urls(s):
    """Extract URLs from ``s``, pass them through utils.list_uniq, and join them one per line.

    NOTE(review): list_uniq is presumably an order-preserving de-duplication; confirm in utils.
    """
    found = parse_urls(s)
    deduped = list_uniq(found)
    return to_urls(deduped)
def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download ``dl_url`` into TEMP_DIR via get_download_file and return its result.

    Args:
        dl_url: URL to fetch.
        civitai_key: API key forwarded to get_download_file.
        progress: Gradio progress tracker.
    """
    progress(0, desc=f"Start downloading... {dl_url}")
    return get_download_file(TEMP_DIR, dl_url, civitai_key)
def get_stkey(filename: str, is_validate: bool=True, rfile: str=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Read the keys of a safetensors file, write key lists, and compare against ``rfile``.

    The input file at ``filename`` is deleted afterwards, whether or not processing
    succeeded. Any error raised while processing is reported (print + gr.Warning)
    rather than propagated.

    Args:
        filename: Path of the safetensors file to inspect (deleted on return).
        is_validate: Forwarded to write_safetensors_key.
        rfile: Reference key-list file forwarded to write_safetensors_key and validate_keys.
        progress: Gradio progress tracker.

    Returns:
        ``(paths, metadata, keys, missing, added)``. ``paths`` holds
        ``<stem>.txt``, ``<stem>_missing.txt`` and ``<stem>_added.txt`` (relative
        to the current directory) only when write_safetensors_key returned a truthy
        value. Values not reached because of an error keep their empty defaults.
    """
    paths = []
    metadata = {}
    keys = []
    missing = []
    added = []
    stem = Path(filename).stem
    keys_path = f"{stem}.txt"
    try:
        progress(0, desc=f"Loading keys... {filename}")
        keys = read_safetensors_key(filename)
        if not keys:
            raise Exception("No keys found.")
        progress(0.5, desc=f"Checking keys... {filename}")
        if write_safetensors_key(keys, keys_path, is_validate, rfile):
            paths.extend([keys_path, f"{stem}_missing.txt", f"{stem}_added.txt"])
        missing, added = validate_keys(keys, rfile)
        metadata = read_safetensors_metadata(filename)
    except Exception as e:
        # Deliberate best-effort boundary for the UI: report and return what we have.
        print(f"Error: Failed check {filename}. {e}")
        gr.Warning(f"Error: Failed check {filename}. {e}")
    finally:
        # An exception escaping `finally` would discard the results above, so a
        # missing file or a failed removal must not raise here.
        try:
            Path(filename).unlink(missing_ok=True)
        except OSError as e:
            print(f"Error: Failed to remove {filename}. {e}")
    return paths, metadata, keys, missing, added
def stkey_gr(dl_url: str, civitai_key: str, hf_token: str, urls: list[str], files: list[str],
             is_validate=True, rfile=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: download every URL in ``dl_url`` and extract its safetensors keys.

    Args:
        dl_url: Free text; all http(s) URLs in it are processed in order.
        civitai_key: Civitai API key; falls back to $CIVITAI_API_KEY when empty.
        hf_token: Hugging Face token; falls back to $HF_TOKEN when empty.
        urls: Current URL list, echoed back as the value and choices of the first output.
        files: Current output-file list; generated key-list paths are appended to a copy.
        is_validate: Forwarded to get_stkey.
        rfile: Reference key-list file forwarded to get_stkey.
        progress: Gradio progress tracker.

    Returns:
        A 7-tuple: update for ``urls``, update for ``files``, an update that hides
        a component, then ``metadata, keys, missing, added`` from the last file that
        was processed (empty defaults if none was).
    """
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN")  # default huggingface token
    set_token(hf_token)
    if not civitai_key:
        civitai_key = os.environ.get("CIVITAI_API_KEY")  # default Civitai API key
    dl_urls = parse_urls(dl_url)
    urls = urls or []
    # Work on a copy so the caller's list is not mutated in place.
    files = list(files) if files else []
    metadata = {}
    keys = []
    missing = []
    added = []
    for u in dl_urls:
        file = download_file(u, civitai_key)
        # Skip anything that is not a regular file on disk. A falsy result
        # (presumably a failed download; verify get_download_file) is checked
        # first because Path(None) raises TypeError.
        if not file or not Path(file).is_file():
            continue
        paths, metadata, keys, missing, added = get_stkey(file, is_validate, rfile)
        files.extend(paths)
    progress(1, desc="Processing...")
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=files), gr.update(visible=False), metadata, keys, missing, added