File size: 12,999 Bytes
2b8b77d
 
1eb5467
e1fcf74
2b8b77d
 
 
 
e2371c5
 
 
16908f1
003a054
978580d
 
 
 
 
003a054
 
c6dfa2b
e1fcf74
 
1314d69
f54a55a
 
 
592c68b
292c38f
e1fcf74
e2371c5
592c68b
 
 
 
 
 
7b28dab
145506a
718ba97
7b28dab
9303de6
efbe74e
1314d69
 
145506a
f217e4d
9a397ea
 
592c68b
978580d
718ba97
00fc70b
 
f217e4d
1eb5467
0cbf06a
ca9e441
 
0cbf06a
 
c464ec4
 
0cbf06a
 
 
 
 
 
f40fb7c
cd2465c
dc2976a
f217e4d
ca9e441
f217e4d
64b9ad0
cddcfed
1eb5467
9724323
1314d69
a61f19c
dc2976a
16908f1
 
 
0be383f
ea162a8
 
16908f1
978580d
 
1eb5467
978580d
1eb5467
16908f1
cddcfed
16908f1
 
 
 
5fd0376
c292764
0be383f
ea162a8
0be383f
592c68b
 
718ba97
 
1314d69
cd2465c
 
 
6b97640
 
cd2465c
9272473
cd2465c
 
 
0ac7cc0
6b97640
 
cd2465c
 
 
07b6090
6b97640
 
cd2465c
 
 
9272473
 
07b6090
9272473
43fbea6
9272473
cd2465c
29017ec
 
e517d30
 
29017ec
 
 
 
 
 
 
 
 
 
 
 
 
145506a
0cbf06a
29017ec
f217e4d
 
 
1eb5467
 
d5a8945
7b9e6e4
1eb5467
718ba97
 
f217e4d
592c68b
ccc38b8
 
145506a
 
 
 
ccc38b8
 
ca9e441
ccc38b8
 
 
 
9085d76
0be383f
ccc38b8
 
 
 
 
 
c292764
ccc38b8
 
 
 
 
 
 
0cbf06a
ccc38b8
164edec
592c68b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4f255d
145506a
0cbf06a
1eb5467
718ba97
 
cddcfed
592c68b
 
 
 
 
164edec
cd2465c
50d6862
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import gradio as gr
import spaces
from clip_slider_pipeline import CLIPSliderFlux
from diffusers import FluxPipeline, AutoencoderTiny
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers.utils import load_image
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel
from diffusers.utils import export_to_gif

def process_controlnet_img(image):
    """Turn an input image into a 3-channel Canny edge map for ControlNet conditioning.

    Args:
        image: a PIL image or array-like accepted by ``np.array`` (presumably an
            8-bit RGB image from the Gradio image input -- TODO confirm).

    Returns:
        PIL.Image.Image: the Canny edge map (thresholds 100 / 200) replicated
        into three identical channels.
    """
    controlnet_img = np.array(image)
    # cv2.Canny returns a single-channel uint8 edge map.
    controlnet_img = cv2.Canny(controlnet_img, 100, 200)
    # Expand to H x W x 3 by replicating the edge map across channels. The
    # original code called an undefined ``HWC3`` helper for this step.
    controlnet_img = np.stack([controlnet_img] * 3, axis=-1)
    # The original also forgot to return the result, so callers got None.
    return Image.fromarray(controlnet_img)

# load pipelines
# TAEF1 tiny autoencoder (bfloat16, moved to CUDA); it is passed to the FLUX
# pipeline below as its VAE.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
# FLUX.1-schnell pipeline in bfloat16, using taef1 as the VAE.
# NOTE(review): only taef1 is explicitly moved to CUDA here; presumably
# CLIPSliderFlux (constructed with device="cuda" below) handles placement of
# the rest of the pipeline -- confirm.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
                                    vae=taef1,
                                    torch_dtype=torch.bfloat16)

# Switch the transformer to channels_last memory format and wrap it with
# torch.compile (mode="max-autotune", fullgraph=True).
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
#pipe.enable_model_cpu_offload()
# Slider wrapper used by generate() and update_scales() to find concept
# directions and to render images along them.
clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))


# Model ids for the (currently disabled) ControlNet-Canny path.
# NOTE(review): neither name is used while the lines below stay commented out.
# update_scales() refers to `t5_slider_controlnet`, which only exists in this
# commented-out code (and T5SliderFlux is not imported), so the
# "controlnet canny" branch would fail if it were ever reached.
# NOTE(review): the controlnet checkpoint id names FLUX.1-dev while base_model
# is FLUX.1-schnell -- confirm compatibility before re-enabling.
base_model = 'black-forest-labs/FLUX.1-schnell'
controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Canny-alpha'
# controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
# pipe_controlnet = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
# t5_slider_controlnet = T5SliderFlux(sd_pipe=pipe_controlnet,device=torch.device("cuda"))

@spaces.GPU(duration=200)
def generate(concept_1, concept_2, scale, prompt, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale,
             x_concept_1, x_concept_2,
             avg_diff_x,
             img2img_type = None, img = None,
             controlnet_scale= None, ip_adapter_scale=None,
             ):
    """Find (or reuse) the direction between two concepts and render a sweep along it.

    Args:
        concept_1, concept_2: the two concept texts defining the direction.
        scale: maximum magnitude; images are rendered from ``-scale`` to ``+scale``.
        prompt: text prompt passed to ``clip_slider.generate``.
        seed: seed passed to ``clip_slider.generate``.
        recalc_directions: when truthy, always recompute the direction.
        iterations: ``num_iterations`` for ``find_latent_direction``.
        steps: ``num_inference_steps`` per image.
        interm_steps: number of images in the sweep.
        guidance_scale: accepted for interface compatibility; not forwarded
            (the keyword is commented out in the generate call).
        x_concept_1, x_concept_2: the concept pair that ``avg_diff_x`` was
            computed for (session state).
        avg_diff_x: the cached direction from a previous run (CPU tensor), or
            None before the first run.
        img2img_type, img, controlnet_scale, ip_adapter_scale: unused here.

    Returns:
        tuple containing:
            - a ``gr.update`` for the scale slider (label and value);
            - the updated concept pair;
            - the direction moved to the CPU;
            - the ``export_to_gif`` result for the sweep;
            - a horizontal strip of 256x256 thumbnails.
    """
    slider_x = [concept_1, concept_2]
    # check if avg diff for directions need to be re-calculated
    print("slider_x", slider_x)
    print("x_concept_1", x_concept_1, "x_concept_2", x_concept_2)
    #torch.manual_seed(seed)

    # The comparison is order-sensitive. find_latent_direction receives the
    # concepts in order, and the slider label below follows the new order, so a
    # swapped pair must not silently reuse the old direction. A missing cache
    # also forces a recompute; otherwise ``avg_diff`` would be unbound below.
    needs_recalc = (
        recalc_directions
        or avg_diff_x is None
        or slider_x != [x_concept_1, x_concept_2]
    )
    if needs_recalc:
        #avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations).to(torch.float16)
        avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
        x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
    else:
        # Reuse the cached direction. It is stored on the CPU (see the return
        # below), so move it back to the GPU.
        avg_diff = avg_diff_x.cuda()

    # Spread ``interm_steps`` images evenly over [-scale, +scale]. The divisor
    # is the number of gaps between images, not the number of inference steps.
    # Dividing by steps - 1 overshot the range and divided by zero at steps=1.
    n_images = int(interm_steps)
    high_scale = scale
    low_scale = -1 * scale
    gaps = max(n_images - 1, 1)
    images = []
    for i in range(n_images):
        cur_scale = low_scale + (high_scale - low_scale) * i / gaps
        image = clip_slider.generate(prompt,
                                     #guidance_scale=guidance_scale,
                                     scale=cur_scale,  seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
        images.append(image)

    # Horizontal strip of 256x256 thumbnails, one per image.
    canvas = Image.new('RGB', (256 * n_images, 256))
    for i, im in enumerate(images):
        canvas.paste(im.resize((256, 256)), (256 * i, 0))

    comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"

    # Store the direction on the CPU in session state.
    avg_diff_x = avg_diff.cpu()

    return gr.update(label=comma_concepts_x, interactive=True, value=scale), x_concept_1, x_concept_2, avg_diff_x, export_to_gif(images, "clip.gif", fps=5), canvas

@spaces.GPU
def update_scales(x,prompt,seed, steps, interm_steps, guidance_scale,
                  avg_diff_x,
                  img2img_type = None, img = None,
                  controlnet_scale= None, ip_adapter_scale=None,):
    """Re-render images along the cached direction using a new maximum scale ``x``.

    Args:
        x: maximum magnitude; the text-to-image sweep goes from ``-x`` to ``+x``.
        prompt, seed, steps: forwarded to the slider's ``generate``.
        interm_steps: number of images in the text-to-image sweep.
        guidance_scale: forwarded only in the controlnet / ip-adapter branches.
        avg_diff_x: cached direction from ``generate`` (CPU tensor), or None if
            "find directions" has not been run yet.
        img2img_type, img, controlnet_scale, ip_adapter_scale: optional
            image-conditioned modes ("controlnet canny" / "ip adapter"); they
            are not wired up by the current UI.

    Returns:
        tuple: the ``export_to_gif`` result and a strip of 256x256 thumbnails.

    Raises:
        gr.Error: if no direction has been computed yet.
    """
    print("Hola", x)
    if avg_diff_x is None:
        # Releasing the slider before "find directions" used to crash with
        # AttributeError on None.cuda(); report it to the user instead.
        raise gr.Error("No concept direction yet - click 'find directions' first.")
    avg_diff = avg_diff_x.cuda()

    # for spectrum generation
    images = []

    high_scale = x
    low_scale = -1 * x

    if img2img_type=="controlnet canny" and img is not None:
        control_img = process_controlnet_img(img)
        # NOTE(review): `t5_slider_controlnet` is only defined in commented-out
        # code in this file; this branch fails until it is restored.
        image = t5_slider_controlnet.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=x, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
        images.append(image)
    elif img2img_type=="ip adapter" and img is not None:
        image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=x,seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
        images.append(image)
    else:
        # Evenly spaced over [-x, +x]. The divisor is the number of gaps between
        # images, not the number of inference steps.
        n_images = int(interm_steps)
        gaps = max(n_images - 1, 1)
        for i in range(n_images):
            cur_scale = low_scale + (high_scale - low_scale) * i / gaps
            image = clip_slider.generate(prompt,
                                         #guidance_scale=guidance_scale,
                                         scale=cur_scale,  seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
            images.append(image)

    # Build the thumbnail strip for every branch. Previously it was built only
    # in the sweep branch, so the other branches hit an unbound `canvas`.
    canvas = Image.new('RGB', (256 * len(images), 256))
    for i, im in enumerate(images):
        canvas.paste(im.resize((256, 256)), (256 * i, 0))
    return export_to_gif(images, "clip.gif", fps=5), canvas


def reset_recalc_directions():
    """Return the flag value that makes generate() recompute the concept direction.

    Bound to changes of the iteration count and the seed, so a direction
    computed under the old settings is not reused.
    """
    force_recalc = True
    return force_recalc

# Custom CSS for the demo layout: the #group container, #x / #y slider
# positioning, and the #image_out image inside the group.
# NOTE(review): this string is not passed to gr.Blocks(css=...) in this file,
# so it currently has no effect. "#x" and "#y" also match no visible component:
# the x slider has no elem_id and the y slider is commented out.
css = '''
#group {
    position: relative;
    width: 600px; /* Increased width */
    height: 600px; /* Increased height */
    margin-bottom: 20px;
    background-color: white;
}
#x {
    position: absolute;
    bottom: 20px; /* Moved further down */
    left: 30px; /* Adjusted left margin */
    width: 540px; /* Increased width to match the new container size */
}
#y {
    position: absolute;
    bottom: 200px; /* Increased bottom margin to ensure proper spacing from #x */
    left: 20px; /* Adjusted left margin */
    width: 540px; /* Increased width to match the new container size */
    transform: rotate(-90deg);
    transform-origin: left bottom;
}
#image_out {
    position: absolute;
    width: 80%; /* Adjust width as needed */
    right: 10px;
    top: 10px; /* Increased top margin to clear space occupied by #x */
}
'''
# HTML header rendered at the top of the demo via gr.HTML(intro). It contains
# the logo, the title and subtitle, a link to the code repository, and a
# "Duplicate Space" badge link.
intro = """
<div style="display: flex;align-items: center;justify-content: center">
    <img src="https://huggingface.co/spaces/LatentNavigation/latentnavigation-flux/resolve/main/Group 4-16.png" width="100" style="display: inline-block">
    <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">Latent Navigation</h1>
</div>
<div style="display: flex;align-items: center;justify-content: center">
    <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Exploring CLIP text space with FLUX.1 schnell 🪐</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
    <a href="https://github.com/linoytsaban/semantic-sliders" target="_blank">code</a>
     | 
    <a href="https://huggingface.co/spaces/LatentNavigation/latentnavigation-flux?duplicate=true" target="_blank" style="
        display: inline-block;
    ">
    <img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
with gr.Blocks() as demo:

    gr.HTML(intro)

    # Per-session state: the concept pair that the cached direction was
    # computed for, and the cached direction itself (returned by generate()).
    x_concept_1 = gr.State("")
    x_concept_2 = gr.State("")
    # y_concept_1 = gr.State("")
    # y_concept_2 = gr.State("")

    avg_diff_x = gr.State()
    #avg_diff_y = gr.State()

    # When True, generate() recomputes the direction even for unchanged
    # concepts. It is set by changes to `iterations` / `seed` and cleared after
    # a successful "find directions" run (see the .success() chain below).
    recalc_directions = gr.State(False)

    #with gr.Tab("text2image"):
    with gr.Row():
        with gr.Column():
            with gr.Row():
                concept_1 = gr.Textbox(label="A concept to compare")
                concept_2 = gr.Textbox(label="Concept to compare")
            #slider_x = gr.Dropdown(label="Slider concept range", allow_custom_value=True, multiselect=True, max_choices=2)
            #slider_y = gr.Dropdown(label="Slider Y concept range", allow_custom_value=True, multiselect=True, max_choices=2)
            prompt = gr.Textbox(label="Prompt")
            x = gr.Slider(minimum=0, value=1.25, step=0.1, maximum=2.5, info="the strength to scale in each direction")
            submit = gr.Button("find directions")
        with gr.Column():
            with gr.Group(elem_id="group"):
                #y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
                output_image = gr.Image(elem_id="image_out")
            image_seq = gr.Image()
            # with gr.Row():
            #     generate_butt = gr.Button("generate")

    with gr.Accordion(label="advanced options", open=False):
        iterations = gr.Slider(label = "num iterations", minimum=0, value=200, maximum=400)
        steps = gr.Slider(label = "num inference steps", minimum=1, value=4, maximum=10)
        interm_steps = gr.Slider(label = "num of intermediate images", minimum=1, value=5, maximum=65)
        guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.1,
                maximum=10.0,
                step=0.1,
                value=5,
            )

        seed  = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)


    # with gr.Tab(label="image2image"):
    #     with gr.Row():
    #         with gr.Column():
    #             image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
    #             slider_x_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
    #             slider_y_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
    #             img2img_type = gr.Radio(["controlnet canny", "ip adapter"], label="", info="", visible=False, value="controlnet canny")
    #             prompt_a = gr.Textbox(label="Prompt")
    #             submit_a = gr.Button("Submit")
    #         with gr.Column():
    #             with gr.Group(elem_id="group"):
    #               x_a = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="x", interactive=False)
    #               y_a = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
    #               output_image_a = gr.Image(elem_id="image_out")
    #             with gr.Row():
    #                 generate_butt_a = gr.Button("generate")

    #     with gr.Accordion(label="advanced options", open=False):
    #         iterations_a = gr.Slider(label = "num iterations", minimum=0, value=200, maximum=300)
    #         steps_a = gr.Slider(label = "num inference steps", minimum=1, value=8, maximum=30)
    #         guidance_scale_a = gr.Slider(
    #                 label="Guidance scale",
    #                 minimum=0.1,
    #                 maximum=10.0,
    #                 step=0.1,
    #                 value=5,
    #             )
    #         controlnet_conditioning_scale = gr.Slider(
    #                 label="controlnet conditioning scale",
    #                 minimum=0.5,
    #                 maximum=5.0,
    #                 step=0.1,
    #                 value=0.7,
    #             )
    #         ip_adapter_scale = gr.Slider(
    #                 label="ip adapter scale",
    #                 minimum=0.5,
    #                 maximum=5.0,
    #                 step=0.1,
    #                 value=0.8,
    #                 visible=False
    #             )
    #         seed_a  = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)

    # submit.click(fn=generate,
    #                  inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y],
    #                  outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])

    # "find directions": compute (or reuse) the direction and render the sweep.
    # generate() does not return the recalc flag, so clear it in a follow-up
    # step that only runs on success. Without this, the first seed/iterations
    # change would force a recompute on every later click.
    submit.click(fn=generate,
                     inputs=[concept_1, concept_2, x, prompt, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale, x_concept_1, x_concept_2, avg_diff_x],
                     outputs=[x, x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq]
                 ).success(fn=lambda: False, outputs=[recalc_directions])

    # Settings that invalidate the cached direction.
    iterations.change(fn=reset_recalc_directions, outputs=[recalc_directions])
    seed.change(fn=reset_recalc_directions, outputs=[recalc_directions])
    # Re-render with the cached direction when the scale slider is released.
    x.release(fn=update_scales, inputs=[x, prompt, seed, steps, interm_steps, guidance_scale, avg_diff_x], outputs=[output_image, image_seq], trigger_mode='always_last')
    # generate_butt_a.click(fn=update_scales, inputs=[x_a,y_a, prompt_a, seed_a, steps_a, guidance_scale_a, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale], outputs=[output_image_a])
    # submit_a.click(fn=generate,
    #                  inputs=[slider_x_a, slider_y_a, prompt_a, seed_a, iterations_a, steps_a, guidance_scale_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale],
    #                  outputs=[x_a, y_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image_a])


if __name__ == "__main__":
    demo.launch()