File size: 14,147 Bytes
4120479
7c6dd97
c694422
ab7c449
9b729f7
ca25d4d
af5ea8a
c6550c9
5b09dff
4120479
 
 
1475e41
3eb8dac
7c6dd97
5042a41
4120479
 
 
af5ea8a
 
 
 
 
 
 
 
 
 
 
 
4120479
 
 
 
20706a7
4120479
 
 
 
 
 
 
 
 
 
 
 
5042a41
 
9b729f7
 
ba35348
5042a41
4aa6570
5042a41
735a271
5042a41
 
 
 
 
2f40f84
4120479
edf408e
ce04d24
 
2130441
b819231
ce04d24
 
b0fbf54
4120479
b0fbf54
b66cade
bc0b5e4
ca25d4d
 
bc0b5e4
ca25d4d
 
 
 
92c7c82
bc5aade
 
b0fbf54
 
7b45d83
47b5daf
4120479
 
ad8076c
3eb8dac
5042a41
3ff5e41
a1a833d
5042a41
 
c7fa047
a1a833d
ffede06
fb4901e
ffede06
fb4901e
a1a833d
 
 
c694422
fb4901e
 
1323497
fb4901e
 
1323497
 
 
 
a1a833d
ffede06
0fac10f
a1a833d
c7fa047
3b6af48
4120479
3b6af48
4120479
a1a833d
 
4120479
a5454cf
af5ea8a
 
 
 
 
ce1c1c2
4120479
 
 
145660e
786f873
00e9fdd
786f873
 
4120479
ea9cf0a
4120479
 
0052f82
4120479
0052f82
 
4120479
 
9155e06
786f873
 
a56c826
 
4120479
06bf887
9155e06
965ea29
4120479
af5ea8a
 
 
0bd58d3
af5ea8a
 
 
8428948
98e159e
af5ea8a
9107950
 
 
4120479
 
 
cad7f41
b819231
4120479
 
 
0d23c84
 
ce04d24
5b09dff
786f873
fdf6122
d1f8613
ce04d24
 
 
5b09dff
786f873
fdf6122
d1f8613
ce04d24
edf408e
ce04d24
7b45d83
0d23c84
b0fbf54
 
3b258c8
ce04d24
29b4824
92c7c82
b48fe41
b157e21
35002e5
b48fe41
b157e21
35002e5
b48fe41
ca25d4d
 
 
 
4120479
 
 
 
0eecc9a
 
 
cad7f41
b6fa736
9155e06
 
4120479
9107950
03d99f7
 
9107950
03d99f7
 
8428948
98e159e
 
ca25d4d
0eecc9a
0a50fe3
d92fc15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import gradio as gr
from time import sleep, time
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download, CommitScheduler
from safetensors.torch import load_file
from share_btn import community_icon_html, loading_icon_html, share_js
from uuid import uuid4
from pathlib import Path
from PIL import Image
import torch
import json
import random
import copy
import gc
import pickle
import spaces

# Local path of the LoRA catalogue (sdxl_loras.json) downloaded from the
# LoraTheExplorer Space repository.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")

# Feedback is written to a folder whose name gets a fresh uuid4 suffix every
# time this module is imported; the images and metadata.jsonl both live there.
IMAGE_DATASET_DIR = Path("image_dataset") / f"train-{uuid4()}"
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"

# Pushes the feedback folder to the preferences dataset repo under a directory
# of the same name. save_preferences writes while holding `scheduler.lock`.
# NOTE(review): `every` is presumably the commit interval in minutes — confirm
# against the huggingface_hub CommitScheduler documentation.
scheduler = CommitScheduler(
    repo_id="multimodalart/lora-fusing-preferences",
    repo_type="dataset",
    folder_path=IMAGE_DATASET_DIR,
    path_in_repo=IMAGE_DATASET_DIR.name,
    every=10
)

# Parse the downloaded LoRA catalogue. The encoding is pinned to UTF-8 so the
# result does not depend on the platform's locale default.
with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalized catalogue records used by the rest of the app. Built after the
# file is closed, since only the parsed `data` is needed.
sdxl_loras = [
    {
        # Image paths that are not already https URLs are resolved against the
        # LoraTheExplorer Space repository.
        "image": item["image"] if item["image"].startswith("https://") else f'https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/{item["image"]}',
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        # Optional keys fall back to defaults when absent from the catalogue.
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Every catalogue entry's weights are downloaded and loaded onto the CPU at
# import time, keyed by repo id, so requests never have to hit the Hub.
state_dicts = {}

for item in sdxl_loras:
    saved_name = hf_hub_download(item["repo"], item["weights"])

    if saved_name.endswith('.safetensors'):
        state_dict = load_file(saved_name, device="cpu")
    else:
        # NOTE(review): torch.load unpickles the downloaded file, which can
        # execute arbitrary code; only acceptable for trusted repositories.
        state_dict = torch.load(saved_name, map_location=torch.device('cpu'))

    state_dicts[item["repo"]] = dict(saved_name=saved_name, state_dict=state_dict)

# Custom stylesheet handed to gr.Blocks(css=...). The selectors target the
# elem_id / elem_classes values assigned to components in the layout below.
css = '''
.gradio-container{max-width: 650px! important}
#title{text-align:center;}
#title h1{font-size: 250%}
.selected_random img{object-fit: cover}
.selected_random [data-testid="block-label"] span{display: none}
.plus_column{align-self: center}
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
#prompt{padding: 0 0 1em 0}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position: absolute;margin-top: 25.8px;right: 0;margin-right: 0.75em;border-bottom-left-radius: 0px;border-top-left-radius: 0px}
.random_column{align-self: center; align-items: center;gap: 0.5em !important}
#share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;}
div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
#share-btn-container:hover {background-color: #060606}
#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;}
#share-btn * {all: unset}
#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
#share-btn-container .wrap {display: none !important}
#share-btn-container.hidden {display: none!important}
#post_gen_info{margin-top: .5em}
#thumbs_up_clicked{background:green}
#thumbs_down_clicked{background:red}
.title_lora a{color: var(--body-text-color) !important; opacity:0.6}
#prompt_area .form{border:0}
#reroll_button{position: absolute;right: 0;flex-grow: 1;min-width: 75px;padding: .1em}
.pending .min {min-height: auto}
'''

# Base SDXL pipeline, loaded once at import in float16. merge_and_run deep-copies
# its UNet and both text encoders per request (LoRAs are fused into the copies)
# and reuses the VAE, scheduler and tokenizers directly.
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

def _lora_state_dict_on_gpu(repo_id):
    """Return a CUDA/float16 copy of the cached LoRA weights for *repo_id*.

    Non-tensor entries are dropped. The cached tensors are loaded on the CPU,
    so ``.to(device="cuda")`` already allocates new tensors inside a brand-new
    dict; the cache is never touched and no CPU-side deep copy is required.
    """
    cached = state_dicts[repo_id]["state_dict"]
    return {
        key: value.to(device="cuda", dtype=torch.float16)
        for key, value in cached.items()
        if torch.is_tensor(value)
    }


@spaces.GPU
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1):
    """Fuse the two selected LoRAs into a fresh SDXL pipeline and render one image.

    Args:
        prompt: Text prompt for the generation.
        negative_prompt: Negative prompt; an empty value is treated as "none".
        shuffled_items: The two catalogue entries currently on display; only
            their ``'repo'`` keys are read.
        lora_1_scale: Fuse scale for the first LoRA.
        lora_2_scale: Fuse scale for the second LoRA.
        seed: Generator seed; any negative value requests a random seed.

    Returns:
        A 7-tuple: the generated image, an update revealing the post-generation
        row, the seed actually used, and four updates that reset the thumbs
        buttons (unclicked variants visible and interactive, clicked variants
        hidden).
    """
    repo_id_1 = shuffled_items[0]['repo']
    repo_id_2 = shuffled_items[1]['repo']

    print("Loading state dicts...")
    start_time = time()
    state_dict_1 = _lora_state_dict_on_gpu(repo_id_1)
    state_dict_2 = _lora_state_dict_on_gpu(repo_id_2)
    state_dict_time = time() - start_time
    print(f"State Dict time: {state_dict_time}")

    # The LoRAs are fused into the weights, so the UNet and text encoders are
    # copied per request to keep the shared base pipeline pristine.
    start_time = time()
    unet = copy.deepcopy(original_pipe.unet)
    text_encoder = copy.deepcopy(original_pipe.text_encoder)
    text_encoder_2 = copy.deepcopy(original_pipe.text_encoder_2)
    pipe = StableDiffusionXLPipeline(vae=original_pipe.vae,
                                     text_encoder=text_encoder,
                                     text_encoder_2=text_encoder_2,
                                     scheduler=original_pipe.scheduler,
                                     tokenizer=original_pipe.tokenizer,
                                     tokenizer_2=original_pipe.tokenizer_2,
                                     unet=unet)
    copy_time = time() - start_time
    print(f"copy time: {copy_time}")
    pipe.to("cuda")

    start_time = time()
    print("Loading LoRA weights...")
    pipe.load_lora_weights(state_dict_1, low_cpu_mem_usage=True)
    pipe.fuse_lora(lora_1_scale)
    pipe.load_lora_weights(state_dict_2, low_cpu_mem_usage=True)
    pipe.fuse_lora(lora_2_scale)
    lora_time = time() - start_time
    print(f"Loaded LoRAs time: {lora_time}")

    # An empty negative prompt is passed to the pipeline as None.
    negative_prompt = negative_prompt or None

    # UI number widgets may deliver floats, while manual_seed needs an int.
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, 2147483647)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
    return image, gr.update(visible=True), seed, gr.update(visible=True, interactive=True), gr.update(visible=False), gr.update(visible=True, interactive=True), gr.update(visible=False)

def get_description(item):
    """Build the trigger hint shown under a LoRA.

    Returns a ``(description, trigger_word)`` tuple, where ``trigger_word`` is
    ``item["trigger_word"]`` unchanged and ``description`` is a fallback message
    when that value is falsy.
    """
    trigger_word = item["trigger_word"]
    if trigger_word:
        description = f"Trigger: `{trigger_word}`"
    else:
        description = "No trigger, applied automatically"
    return description, trigger_word

def truncate_string(s, max_length=29):
    """Shorten *s* with a trailing ``...`` when it is longer than *max_length*.

    Strings of at most *max_length* characters are returned unchanged;
    otherwise the first ``max_length - 3`` characters are kept and ``"..."``
    is appended.
    """
    if len(s) <= max_length:
        return s
    kept = s[:max_length - 3]
    return kept + "..."

def shuffle_images():
    """Pick two random compatible LoRAs and build the UI updates that show them.

    Returns:
        A 12-tuple, in order: link markdown, image update, trigger-description
        update and repo-id update for the first LoRA; the same four values for
        the second LoRA; the pre-filled prompt update; the list of the two
        selected catalogue entries; and the scale update twice (one per LoRA
        scale slider, both reset to 0.7).

    Raises:
        ValueError: If fewer than two catalogue entries are marked compatible.
    """
    compatible_items = [item for item in sdxl_loras if item['is_compatible']]
    # random.sample draws two distinct entries without shuffling the whole list.
    two_shuffled_items = random.sample(compatible_items, 2)
    first, second = two_shuffled_items

    title_1 = gr.update(label=first['title'], value=first['image'])
    title_2 = gr.update(label=second['title'], value=second['image'])
    repo_id_1 = gr.update(value=first['repo'])
    repo_id_2 = gr.update(value=second['repo'])
    description_1, trigger_word_1 = get_description(first)
    description_2, trigger_word_2 = get_description(second)

    lora_1_link = f"[{truncate_string(first['repo'])}](https://huggingface.co/{first['repo']}) ✨"
    lora_2_link = f"[{truncate_string(second['repo'])}](https://huggingface.co/{second['repo']}) ✨"
    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)

    # Only non-empty trigger words go into the starter prompt, so a LoRA
    # without a trigger leaves neither a stray space nor a literal "None".
    trigger_words = [word for word in (trigger_word_1, trigger_word_2) if word]
    prompt = gr.update(value=" ".join(trigger_words))
    scale = gr.update(value=0.7)

    return lora_1_link, title_1, prompt_description_1, repo_id_1, lora_2_link, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale

def save_preferences(lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, generated_image, thumbs_direction, seed):
    """Persist one thumbs-up/down vote: the image plus a JSONL metadata row.

    Args:
        lora_1_id: Repo id of the first LoRA.
        lora_1_scale: Scale reported for the first LoRA.
        lora_2_id: Repo id of the second LoRA.
        lora_2_scale: Scale reported for the second LoRA.
        prompt: Prompt text associated with the image.
        generated_image: Image as an array accepted by ``PIL.Image.fromarray``.
        thumbs_direction: Vote label stored verbatim (the UI passes "up"/"down").
        seed: Seed value stored verbatim.

    Returns:
        Three updates: hide the button that was clicked, show its "clicked"
        counterpart, and make the opposite button non-interactive.
    """
    image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
    record = {
        "prompt": prompt,
        "file_name": image_path.name,
        # Each LoRA is recorded under its own key (the first id used to be
        # overwritten with the second one).
        "lora_1_id": lora_1_id,
        "lora_1_scale": lora_1_scale,
        "lora_2_id": lora_2_id,
        "lora_2_scale": lora_2_scale,
        "thumbs_direction": thumbs_direction,
        "seed": seed,
    }
    # Image and metadata row are written under the scheduler's lock so they
    # are added together with respect to anything else that takes that lock.
    with scheduler.lock:
        Image.fromarray(generated_image).save(image_path)
        with IMAGE_JSONL_PATH.open("a") as f:
            json.dump(record, f)
            f.write("\n")

    return gr.update(visible=False), gr.update(visible=True), gr.update(interactive=False)

def hide_post_gen_info():
    """Return the update that hides the post-generation row."""
    hidden = gr.update(visible=False)
    return hidden

# UI layout and event wiring for the LoRA roulette demo.
with gr.Blocks(css=css) as demo:
  # Holds the two catalogue entries selected by shuffle_images.
  shuffled_items = gr.State()
  title = gr.HTML(
        '''<h1>LoRA Roulette 🎰</h1>
        <p>This random LoRAs are loaded into SDXL, can you find a fun way to combine them? 🎨</p>
        ''',
        elem_id="title"
  )
  with gr.Column():
    # Top row: LoRA 1 "+" LoRA 2 "=" with the reroll button.
    with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
        with gr.Row():
            with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
              lora_1_link = gr.Markdown(elem_classes="title_lora")
              lora_1 = gr.Image(interactive=False, height=150, elem_classes="selected_random", elem_id="randomLoRA_1", show_share_button=False, show_download_button=False)
              lora_1_id = gr.Textbox(visible=False, elem_id="random_lora_1_id")
              lora_1_prompt = gr.Markdown(visible=False)
            with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
              plus = gr.HTML("+", elem_classes="plus_button")
            with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
              lora_2_link = gr.Markdown(elem_classes="title_lora")
              lora_2 = gr.Image(interactive=False, height=150, elem_classes="selected_random", elem_id="randomLoRA_2", show_share_button=False, show_download_button=False)
              lora_2_id = gr.Textbox(visible=False, elem_id="random_lora_2_id")
              lora_2_prompt = gr.Markdown(visible=False)
            with gr.Column(min_width=10, scale=2, elem_classes="plus_column"):
               equal = gr.HTML("=", elem_classes="plus_button")
            shuffle_button = gr.Button("🎲  reroll", elem_id="reroll_button")
    # Prompt box, output image and the (initially hidden) feedback row.
    with gr.Column(min_width=10, scale=14):
        with gr.Box(elem_id="generate_area"):
            with gr.Row(elem_id="prompt_area"):
                prompt = gr.Textbox(label="Your prompt", info="Rearrange the trigger words into a coherent prompt", show_label=False, interactive=True, elem_id="prompt")
                run_btn = gr.Button("Run", elem_id="run_button")
            output_image = gr.Image(label="Output", height=355, elem_id="output_image", interactive=False)
            with gr.Row(visible=False, elem_id="post_gen_info") as post_gen_info:
                with gr.Column(min_width=10):
                    thumbs_up = gr.Button("👍", elem_id="thumbs_up_unclicked")
                    thumbs_up_clicked = gr.Button("👍", elem_id="thumbs_up_clicked", interactive=False, visible=False)
                with gr.Column(min_width=10):
                    thumbs_down = gr.Button("👎", elem_id="thumbs_down_unclicked")
                    thumbs_down_clicked = gr.Button("👎", elem_id="thumbs_down_clicked", interactive=False, visible=False)
                with gr.Column(min_width=10):
                    with gr.Group(elem_id="share-btn-container") as share_group:
                        community_icon = gr.HTML(community_icon_html)
                        loading_icon = gr.HTML(loading_icon_html)
                        share_button = gr.Button("Share to community", elem_id="share-btn")
  with gr.Accordion("Advanced settings", open=False):
    with gr.Row():
      lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
      lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    negative_prompt = gr.Textbox(label="Negative prompt")
    seed = gr.Slider(label="Seed", info="-1 denotes a random seed", minimum=-1, maximum=2147483647, value=-1)
    # precision=0 keeps the value an integer when it is read back as an input.
    last_used_seed = gr.Number(label="Last used seed", info="The seed used in the last generation", minimum=0, maximum=2147483647, value=-1, precision=0, interactive=False)
  gr.Markdown("Generate with intent in [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), but remember: sometimes restrictions flourish creativity 🌸")

  # A new random pair is drawn on page load and on every reroll.
  shuffle_outputs = [lora_1_link, lora_1, lora_1_prompt, lora_1_id, lora_2_link, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale]
  demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
  shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")

  # Run button and Enter in the prompt box trigger the same two-step chain:
  # hide the feedback row, then generate.
  generate_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale, seed]
  generate_outputs = [output_image, post_gen_info, last_used_seed, thumbs_up, thumbs_up_clicked, thumbs_down, thumbs_down_clicked]
  run_btn.click(hide_post_gen_info, outputs=[post_gen_info], queue=False).then(merge_and_run,
                inputs=generate_inputs,
                outputs=generate_outputs)
  prompt.submit(hide_post_gen_info, outputs=[post_gen_info], queue=False).then(merge_and_run,
                inputs=generate_inputs,
                outputs=generate_outputs)

  # Votes record `last_used_seed` (the seed the image was actually generated
  # with) rather than the seed slider, which reads -1 for random-seed runs.
  thumbs_up.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("up"), last_used_seed], outputs=[thumbs_up, thumbs_up_clicked, thumbs_down])
  thumbs_down.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("down"), last_used_seed], outputs=[thumbs_down, thumbs_down_clicked, thumbs_up])
  share_button.click(None, [], [], _js=share_js)

# Turn on request queuing, then start the app.
# NOTE(review): `concurrency_count` looks like a Gradio 3.x queue option —
# confirm it is still accepted by the pinned Gradio version.
demo.queue(concurrency_count=2)
demo.launch()