John6666 commited on
Commit
26a0cbe
1 Parent(s): 4059875

Upload 4 files

Browse files
Files changed (3) hide show
  1. app.py +15 -8
  2. flux.py +2 -75
  3. modutils.py +74 -7
app.py CHANGED
@@ -21,7 +21,8 @@ from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists
21
  get_control_union_mode, set_control_union_mode, get_control_params, translate_to_en)
22
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
23
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
24
- update_loras, get_t2i_model_info)
 
25
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
26
  from tagger.fl2flux import predict_tags_fl2_flux
27
 
@@ -111,6 +112,11 @@ def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, prog
111
 
112
  change_base_model.zerogpu = True
113
 
 
 
 
 
 
114
  class calculateDuration:
115
  def __init__(self, activity_name=""):
116
  self.activity_name = activity_name
@@ -130,7 +136,7 @@ class calculateDuration:
130
  def download_file(url, directory=None):
131
  if directory is None:
132
  directory = os.getcwd() # Use current working directory if not specified
133
-
134
  # Get the filename from the URL
135
  filename = url.split('/')[-1]
136
 
@@ -146,7 +152,7 @@ def download_file(url, directory=None):
146
  file.write(response.content)
147
 
148
  return filepath
149
-
150
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
151
  selected_index = evt.index
152
  selected_indices = selected_indices or []
@@ -243,7 +249,8 @@ def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
243
  existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
244
  if existing_item_index is None:
245
  if repo.endswith(".safetensors") and repo.startswith("http"):
246
- repo = download_file(repo)
 
247
  new_item = {
248
  "image": image if image else "/home/user/app/custom.png",
249
  "title": title,
@@ -504,14 +511,14 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
504
  print(f"Lora Path: {lora_path}")
505
  if image_input is not None:
506
  if weight_name:
507
- pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
508
  else:
509
- pipe_i2i.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
510
  else:
511
  if weight_name:
512
- pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
513
  else:
514
- pipe.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
515
  print("Loaded LoRAs:", lora_names)
516
  if image_input is not None:
517
  pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
 
21
  get_control_union_mode, set_control_union_mode, get_control_params, translate_to_en)
22
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
23
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
24
+ update_loras)
25
+ from modutils import get_t2i_model_info, get_download_file
26
  from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
27
  from tagger.fl2flux import predict_tags_fl2_flux
28
 
 
112
 
113
  change_base_model.zerogpu = True
114
 
115
+ def download_file_mod(url, directory=os.getcwd()):
116
+ path = get_download_file(directory, url)
117
+ if not path: raise Exception(f"Download error: {url}")
118
+ return path
119
+
120
  class calculateDuration:
121
  def __init__(self, activity_name=""):
122
  self.activity_name = activity_name
 
136
  def download_file(url, directory=None):
137
  if directory is None:
138
  directory = os.getcwd() # Use current working directory if not specified
139
+
140
  # Get the filename from the URL
141
  filename = url.split('/')[-1]
142
 
 
152
  file.write(response.content)
153
 
154
  return filepath
155
+
156
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
157
  selected_index = evt.index
158
  selected_indices = selected_indices or []
 
249
  existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
250
  if existing_item_index is None:
251
  if repo.endswith(".safetensors") and repo.startswith("http"):
252
+ #repo = download_file(repo)
253
+ repo = download_file_mod(repo)
254
  new_item = {
255
  "image": image if image else "/home/user/app/custom.png",
256
  "title": title,
 
511
  print(f"Lora Path: {lora_path}")
512
  if image_input is not None:
513
  if weight_name:
514
+ pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
515
  else:
516
+ pipe_i2i.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
517
  else:
518
  if weight_name:
519
+ pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
520
  else:
521
+ pipe.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name, token=HF_TOKEN)
522
  print("Loaded LoRAs:", lora_names)
523
  if image_input is not None:
524
  pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
flux.py CHANGED
@@ -17,54 +17,11 @@ from env import (HF_TOKEN, hf_read_token, # to use only for private repos
17
  HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
18
  download_model_list, download_lora_list, download_vae_list)
19
  from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
20
- safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list,
21
  get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
22
  get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai, MODEL_TYPE_DICT)
23
 
24
 
25
- def download_things(directory, url, hf_token="", civitai_api_key=""):
26
- url = url.strip()
27
- if "drive.google.com" in url:
28
- original_dir = os.getcwd()
29
- os.chdir(directory)
30
- os.system(f"gdown --fuzzy {url}")
31
- os.chdir(original_dir)
32
- elif "huggingface.co" in url:
33
- url = url.replace("?download=true", "")
34
- # url = urllib.parse.quote(url, safe=':/') # fix encoding
35
- if "/blob/" in url:
36
- url = url.replace("/blob/", "/resolve/")
37
- user_header = f'"Authorization: Bearer {hf_token}"'
38
- if hf_token:
39
- os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
40
- else:
41
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
42
- elif "civitai.com" in url:
43
- if "?" in url:
44
- url = url.split("?")[0]
45
- if civitai_api_key:
46
- url = url + f"?token={civitai_api_key}"
47
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
48
- else:
49
- print("\033[91mYou need an API key to download Civitai models.\033[0m")
50
- else:
51
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
52
-
53
-
54
- def get_model_list(directory_path):
55
- model_list = []
56
- valid_extensions = {'.ckpt' , '.pt', '.pth', '.safetensors', '.bin'}
57
-
58
- for filename in os.listdir(directory_path):
59
- if os.path.splitext(filename)[1] in valid_extensions:
60
- name_without_extension = os.path.splitext(filename)[0]
61
- file_path = os.path.join(directory_path, filename)
62
- # model_list.append((name_without_extension, file_path))
63
- model_list.append(file_path)
64
- print('\033[34mFILE: ' + file_path + '\033[0m')
65
- return model_list
66
-
67
-
68
  # - **Download Models**
69
  download_model = ", ".join(download_model_list)
70
  # - **Download VAEs**
@@ -90,40 +47,10 @@ for url in [url.strip() for url in download_lora.split(',')]:
90
  download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
91
 
92
  lora_model_list = get_lora_model_list()
93
- vae_model_list = get_model_list(directory_vaes)
94
  vae_model_list.insert(0, "None")
95
 
96
 
97
- def is_repo_name(s):
98
- import re
99
- return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
100
-
101
-
102
- def get_t2i_model_info(repo_id: str):
103
- api = HfApi(token=HF_TOKEN)
104
- try:
105
- if not is_repo_name(repo_id): return ""
106
- model = api.model_info(repo_id=repo_id, timeout=5.0)
107
- except Exception as e:
108
- print(f"Error: Failed to get {repo_id}'s info.")
109
- print(e)
110
- return ""
111
- if model.private or model.gated: return ""
112
- tags = model.tags
113
- info = []
114
- url = f"https://huggingface.co/{repo_id}/"
115
- if not 'diffusers' in tags: return ""
116
- for k, v in MODEL_TYPE_DICT.items():
117
- if k in tags: info.append(v)
118
- if model.card_data and model.card_data.tags:
119
- info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
120
- info.append(f"DLs: {model.downloads}")
121
- info.append(f"likes: {model.likes}")
122
- info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
123
- md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
124
- return gr.update(value=md)
125
-
126
-
127
  private_lora_dict = {"": ["", "", "", "", ""]}
128
  try:
129
  with open('lora_dict.json', encoding='utf-8') as f:
 
17
  HF_VAE_PRIVATE_REPO, directory_models, directory_loras, directory_vaes,
18
  download_model_list, download_lora_list, download_vae_list)
19
  from modutils import (to_list, list_uniq, list_sub, get_lora_model_list, download_private_repo,
20
+ safe_float, escape_lora_basename, to_lora_key, to_lora_path, get_local_model_list, download_things,
21
  get_private_lora_model_lists, get_valid_lora_name, get_valid_lora_path, get_valid_lora_wt,
22
  get_lora_info, normalize_prompt_list, get_civitai_info, search_lora_on_civitai, MODEL_TYPE_DICT)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # - **Download Models**
26
  download_model = ", ".join(download_model_list)
27
  # - **Download VAEs**
 
47
  download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
48
 
49
  lora_model_list = get_lora_model_list()
50
+ vae_model_list = get_local_model_list(directory_vaes)
51
  vae_model_list.insert(0, "None")
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  private_lora_dict = {"": ["", "", "", "", ""]}
55
  try:
56
  with open('lora_dict.json', encoding='utf-8') as f:
modutils.py CHANGED
@@ -1,10 +1,12 @@
1
  import spaces
2
  import json
3
  import gradio as gr
4
- from huggingface_hub import HfApi
5
  import os
6
  from pathlib import Path
7
  from PIL import Image
 
 
 
8
 
9
 
10
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
@@ -28,7 +30,7 @@ def to_list(s):
28
 
29
 
30
  def list_uniq(l):
31
- return sorted(set(l), key=l.index)
32
 
33
 
34
  def list_sub(a, b):
@@ -61,6 +63,48 @@ def get_local_model_list(dir_path):
61
  return model_list
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def download_things(directory, url, hf_token="", civitai_api_key=""):
65
  url = url.strip()
66
  if "drive.google.com" in url:
@@ -73,11 +117,7 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
73
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
74
  if "/blob/" in url:
75
  url = url.replace("/blob/", "/resolve/")
76
- user_header = f'"Authorization: Bearer {hf_token}"'
77
- if hf_token:
78
- os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
79
- else:
80
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
81
  elif "civitai.com" in url:
82
  if "?" in url:
83
  url = url.split("?")[0]
@@ -90,6 +130,33 @@ def download_things(directory, url, hf_token="", civitai_api_key=""):
90
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def escape_lora_basename(basename: str):
94
  return basename.replace(".", "_").replace(" ", "_").replace(",", "")
95
 
 
1
  import spaces
2
  import json
3
  import gradio as gr
 
4
  import os
5
  from pathlib import Path
6
  from PIL import Image
7
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
8
+ import urllib.parse
9
+ import re
10
 
11
 
12
  from env import (HF_LORA_PRIVATE_REPOS1, HF_LORA_PRIVATE_REPOS2,
 
30
 
31
 
32
  def list_uniq(l):
33
+ return sorted(set(l), key=l.index)
34
 
35
 
36
  def list_sub(a, b):
 
63
  return model_list
64
 
65
 
66
+ def get_token():
67
+ try:
68
+ token = HfFolder.get_token()
69
+ except Exception:
70
+ token = ""
71
+ return token
72
+
73
+
74
+ def set_token(token):
75
+ try:
76
+ HfFolder.save_token(token)
77
+ except Exception:
78
+ print(f"Error: Failed to save token.")
79
+
80
+
81
+ set_token(HF_TOKEN)
82
+
83
+
84
+ def split_hf_url(url: str):
85
+ try:
86
+ s = list(re.findall(r'^(?:https?://huggingface.co/)(?:(datasets)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?.\w+)(?:\?download=true)?$', url)[0])
87
+ if len(s) < 4: return "", "", "", ""
88
+ repo_id = s[1]
89
+ repo_type = "dataset" if s[0] == "datasets" else "model"
90
+ subfolder = urllib.parse.unquote(s[2]) if s[2] else None
91
+ filename = urllib.parse.unquote(s[3])
92
+ return repo_id, filename, subfolder, repo_type
93
+ except Exception as e:
94
+ print(e)
95
+
96
+
97
+ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
98
+ hf_token = get_token()
99
+ repo_id, filename, subfolder, repo_type = split_hf_url(url)
100
+ try:
101
+ print(f"Downloading {url} to {directory}")
102
+ if subfolder is not None: hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
103
+ else: hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
104
+ except Exception as e:
105
+ print(f"Failed to download: {e}")
106
+
107
+
108
  def download_things(directory, url, hf_token="", civitai_api_key=""):
109
  url = url.strip()
110
  if "drive.google.com" in url:
 
117
  # url = urllib.parse.quote(url, safe=':/') # fix encoding
118
  if "/blob/" in url:
119
  url = url.replace("/blob/", "/resolve/")
120
+ download_hf_file(directory, url)
 
 
 
 
121
  elif "civitai.com" in url:
122
  if "?" in url:
123
  url = url.split("?")[0]
 
130
  os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
131
 
132
 
133
+ def get_download_file(temp_dir, url, civitai_key="", progress=gr.Progress(track_tqdm=True)):
134
+ if not "http" in url and is_repo_name(url) and not Path(url).exists():
135
+ print(f"Use HF Repo: {url}")
136
+ new_file = url
137
+ elif not "http" in url and Path(url).exists():
138
+ print(f"Use local file: {url}")
139
+ new_file = url
140
+ elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
141
+ print(f"File to download alreday exists: {url}")
142
+ new_file = f"{temp_dir}/{url.split('/')[-1]}"
143
+ else:
144
+ print(f"Start downloading: {url}")
145
+ before = get_local_model_list(temp_dir)
146
+ try:
147
+ download_things(temp_dir, url.strip(), HF_TOKEN, civitai_key)
148
+ except Exception:
149
+ print(f"Download failed: {url}")
150
+ return ""
151
+ after = get_local_model_list(temp_dir)
152
+ new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
153
+ if not new_file:
154
+ print(f"Download failed: {url}")
155
+ return ""
156
+ print(f"Download completed: {url}")
157
+ return new_file
158
+
159
+
160
  def escape_lora_basename(basename: str):
161
  return basename.replace(".", "_").replace(" ", "_").replace(",", "")
162