import gradio as gr
import spaces
import torch
from clip_slider_pipeline import CLIPSliderXL
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler, AutoencoderKL
import time

#vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
flash_pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash").to("cuda", torch.float16)
flash_pipe.scheduler = EulerDiscreteScheduler.from_config(flash_pipe.scheduler.config)
clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"))

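# Recompute the CLIP latent directions for the X/Y concept pairs when they change,
# then render a baseline image with both scales at 0.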
@spaces.GPU
def generate(slider_x, slider_y, prompt, iterations, steps, 
             x_concept_1, x_concept_2, y_concept_1, y_concept_2, 
             avg_diff_x_1, avg_diff_x_2,
             avg_diff_y_1, avg_diff_y_2):
    start_time = time.time()
    # re-calculate the latent directions only if the slider concepts changed
    if sorted(slider_x) != sorted([x_concept_1, x_concept_2]):
        avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
        x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
    else:
        # reuse the direction cached in the gr.State components
        avg_diff = [avg_diff_x_1.cuda(), avg_diff_x_2.cuda()]

    if sorted(slider_y) != sorted([y_concept_1, y_concept_2]):
        avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[1], num_iterations=iterations)
        y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
    else:
        avg_diff_2nd = [avg_diff_y_1.cuda(), avg_diff_y_2.cuda()]
    end_time = time.time()
    print(f"direction time: {end_time - start_time:.2f} s")
    start_time = time.time()
    image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
    end_time = time.time()
    print(f"generation time: {end_time - start_time:.2f} s")
    comma_concepts_x = ', '.join(slider_x)
    comma_concepts_y = ', '.join(slider_y)

    avg_diff_x_1 = avg_diff[0].cpu()
    avg_diff_x_2 = avg_diff[1].cpu()
    avg_diff_y_1 = avg_diff_2nd[0].cpu()
    avg_diff_y_2 = avg_diff_2nd[1].cpu()
  
    return gr.update(label=comma_concepts_x, interactive=True), gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, image

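# Slider callbacks: re-generate the image at the current (x, y) scales,
# reusing the latent-direction tensors cached in gr.State.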
@spaces.GPU
def update_x(x, y, prompt, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
    avg_diff = [avg_diff_x_1.cuda(), avg_diff_x_2.cuda()]
    avg_diff_2nd = [avg_diff_y_1.cuda(), avg_diff_y_2.cuda()]
    image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
    return image

@spaces.GPU
def update_y(x, y, prompt, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
    avg_diff = [avg_diff_x_1.cuda(), avg_diff_x_2.cuda()]
    avg_diff_2nd = [avg_diff_y_1.cuda(), avg_diff_y_2.cuda()]
    image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
    return image
  
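# CSS that overlays the X slider along the bottom edge and the rotated Y slider
# along the left edge of the output image.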
css = '''
#group {
    position: relative;
    width: 420px;
    height: 420px;
    margin-bottom: 20px;
    background-color: white
}
#x {
    position: absolute;
    bottom: 0;
    left: 25px;
    width: 400px;
}
#y {
    position: absolute;
    bottom: 20px;
    left: 67px;
    width: 400px;
    transform: rotate(-90deg);
    transform-origin: left bottom;
}
#image_out{position:absolute; width: 80%; right: 10px; top: 40px}
'''
with gr.Blocks(css=css) as demo:
    
    x_concept_1 = gr.State("")
    x_concept_2 = gr.State("")
    y_concept_1 = gr.State("")
    y_concept_2 = gr.State("")

    avg_diff_x_1 = gr.State()
    avg_diff_x_2 = gr.State()
    avg_diff_y_1 = gr.State()
    avg_diff_y_2 = gr.State()
    
    with gr.Row():
        with gr.Column():
            slider_x = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
            slider_y = gr.Dropdown(label="Slider Y concept range", allow_custom_value=True, multiselect=True, max_choices=2)
            prompt = gr.Textbox(label="Prompt")
            submit = gr.Button("Submit")
        with gr.Group(elem_id="group"):
          x = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="x", interactive=False)
          y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
          output_image = gr.Image(elem_id="image_out")
    
    with gr.Accordion(label="Advanced options", open=False):
        iterations = gr.Slider(label="num iterations", minimum=0, value=100, maximum=300)
        steps = gr.Slider(label="num inference steps", minimum=1, value=8, maximum=30)
        
    
    submit.click(fn=generate,
                 inputs=[slider_x, slider_y, prompt, iterations, steps, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
                 outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])
    x.change(fn=update_x, inputs=[x, y, prompt, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
    y.change(fn=update_y, inputs=[x, y, prompt, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])

if __name__ == "__main__":
    demo.launch()