File size: 3,583 Bytes
7847c2d
 
 
 
 
e9d9ca7
 
7847c2d
 
 
 
e9d9ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7847c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7c9dd3
 
7847c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9d9ca7
7847c2d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import os
from pathlib import Path
import gc
import re
import shutil
from utils import set_token, get_download_file, list_uniq
from stkey import read_safetensors_key, read_safetensors_metadata, validate_keys, write_safetensors_key


TEMP_DIR = "."  # working directory where downloaded files are placed
KEYS_DIR = "keys"  # directory holding reference key-list files (*.txt)
KEYS_FILES = [f"{KEYS_DIR}/sdxl_keys.txt"]  # rebuilt from disk by update_keys_files()
DEFAULT_KEYS_FILE = f"{KEYS_DIR}/sdxl_keys.txt"


def update_keys_files():
    """Rescan KEYS_DIR for ``*.txt`` key lists and refresh the global KEYS_FILES."""
    global KEYS_FILES
    KEYS_FILES = [str(path) for path in Path(KEYS_DIR).glob("*.txt")]


update_keys_files()  # populate KEYS_FILES from disk once at import time


def upload_keys_file(path: str):
    """Copy an uploaded key-list file into KEYS_DIR (as ``<stem>.txt``) and
    return a gr.update carrying the refreshed KEYS_FILES choices.

    An existing file with the same name is left untouched.
    """
    dest = Path(KEYS_DIR, Path(path).stem + ".txt")
    if not dest.exists():
        shutil.copy(str(Path(path)), str(dest))
    update_keys_files()
    return gr.update(choices=KEYS_FILES)


# Pre-compiled URL matcher. Raw string avoids the deprecated double-escaping
# the original pattern used; the character set is unchanged.
_URL_RE = re.compile(r"https?://[\w/:%#$&?()~.=+\-]+")


def parse_urls(s):
    """Extract every HTTP/HTTPS URL from *s*.

    Args:
        s: Text that may contain URLs. Non-string input (e.g. ``None``)
           yields an empty list instead of raising, matching the original
           best-effort contract.

    Returns:
        A list of matched URL strings, in order of appearance.
    """
    try:
        return _URL_RE.findall(s)
    except TypeError:  # non-string input; findall cannot raise otherwise here
        return []


def to_urls(l: list[str]):
    """Render a list of URL strings as one URL per line."""
    joined = "\n".join(l)
    return joined


def uniq_urls(s):
    """Extract the URLs in *s*, de-duplicate them, and return them newline-joined."""
    found = parse_urls(s)
    unique = list_uniq(found)
    return to_urls(unique)


def download_file(dl_url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Download *dl_url* into TEMP_DIR and return the local filename reported
    by get_download_file. *civitai_key* is forwarded for Civitai downloads.
    """
    progress(0, desc=f"Start downloading... {dl_url}")
    return get_download_file(TEMP_DIR, dl_url, civitai_key)


def get_stkey(filename: str, is_validate: bool=True, rfile: str=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Read the tensor keys of a safetensors file, optionally validate them
    against a reference key list, and always delete the input file afterwards.

    Args:
        filename: Path of the local safetensors file to inspect. The file is
            removed in all cases (success or failure) before returning.
        is_validate: Forwarded to write_safetensors_key to enable validation.
        rfile: Reference key-list path. NOTE(review): the default is bound to
            KEYS_FILES[0] once at import time, not re-evaluated per call.
        progress: Gradio progress tracker.

    Returns:
        Tuple ``(paths, metadata, keys, missing, added)`` — the key/missing/
        added report file paths, the file's metadata dict, its key list, and
        the keys missing from / added relative to the reference list.
        All values are empty on failure (errors are reported via gr.Warning).
    """
    paths = []
    metadata = {}
    keys = []
    missing = []
    added = []
    try:
        progress(0, desc=f"Loading keys... (unknown)")
        keys = read_safetensors_key(filename)
        if len(keys) == 0: raise Exception("No keys found.")
        progress(0.5, desc=f"Checking keys... (unknown)")
        stem = Path(filename).stem  # hoisted: used for all three report names
        if write_safetensors_key(keys, stem + ".txt", is_validate, rfile):
            paths.append(stem + ".txt")
            paths.append(stem + "_missing.txt")
            paths.append(stem + "_added.txt")
        missing, added = validate_keys(keys, rfile)
        metadata = read_safetensors_metadata(filename)
    except Exception as e:
        print(f"Error: Failed check (unknown). {e}")
        gr.Warning(f"Error: Failed check (unknown). {e}")
    finally:
        # missing_ok=True: don't raise a second error (masking the result)
        # if the file has already been removed.
        Path(filename).unlink(missing_ok=True)
        # NOTE(review): returning from finally suppresses any in-flight
        # exception; preserved because callers rely on always getting a tuple.
        return paths, metadata, keys, missing, added


def stkey_gr(dl_url: str, civitai_key: str, hf_token: str, urls: list[str], files: list[str],
             is_validate=True, rfile=KEYS_FILES[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio callback: download every URL found in *dl_url*, run the key
    check on each file, and return updated component states.

    Args:
        dl_url: Free text containing one or more download URLs.
        civitai_key: Civitai API key; falls back to the CIVITAI_API_KEY env var.
        hf_token: Hugging Face token; falls back to the HF_TOKEN env var.
        urls: Existing URL choices (passed through unchanged).
        files: Existing output file list; report paths are appended to a copy.
        is_validate / rfile: Forwarded to get_stkey.
        progress: Gradio progress tracker.

    Returns:
        Tuple of gr.update objects (url choices, file list, visibility) plus
        the metadata, keys, missing and added values of the LAST processed file.
    """
    if not hf_token: hf_token = os.environ.get("HF_TOKEN") # default huggingface token
    set_token(hf_token)
    if not civitai_key: civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
    dl_urls = parse_urls(dl_url)
    if not urls: urls = []
    # Copy so the caller's list is never mutated by the extend() below.
    files = list(files) if files else []
    metadata = {}
    keys = []
    missing = []
    added = []
    for u in dl_urls:
        file = download_file(u, civitai_key)
        # Guard: download may fail and return a falsy value; Path(None) would raise.
        if not file: continue
        fpath = Path(file)
        if not fpath.exists() or not fpath.is_file(): continue
        paths, metadata, keys, missing, added = get_stkey(file, is_validate, rfile)
        if len(paths) != 0: files.extend(paths)
    progress(1, desc="Processing...")
    gc.collect()
    return gr.update(value=urls, choices=urls), gr.update(value=files), gr.update(visible=False), metadata, keys, missing, added