File size: 4,957 Bytes
cd2465c
3051f7b
f217e4d
 
9ed65fe
f217e4d
0f97695
 
666c6de
79bd7e1
cd2465c
3051f7b
79bd7e1
b1c5569
 
 
f217e4d
 
 
79bd7e1
f217e4d
d5a8945
f217e4d
79bd7e1
f217e4d
 
79bd7e1
cd2465c
 
 
0c5d517
 
 
 
f217e4d
b1c5569
f217e4d
64b9ad0
79bd7e1
6ef419c
 
79bd7e1
4b0fbd1
cd2465c
64b9ad0
79bd7e1
6ef419c
 
79bd7e1
4b0fbd1
cd2465c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbab3de
cd2465c
 
 
 
 
 
 
f217e4d
 
 
 
 
d5a8945
b1c5569
 
 
 
f217e4d
50d6862
 
 
 
 
 
 
 
 
 
0f97695
79bd7e1
 
0f97695
50d6862
cd2465c
79bd7e1
b1c5569
 
 
50d6862
cd2465c
50d6862
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import spaces
import torch
from clip_slider_pipeline import CLIPSliderXL
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler,  AutoencoderKL

# SDXL-Flash pipeline loaded once at import time, cast to fp16 on the GPU.
# NOTE: an fp16-fixed VAE ("madebyollin/sdxl-vae-fp16-fix" via AutoencoderKL)
# was previously tried here and is currently disabled.
flash_pipe = StableDiffusionXLPipeline.from_pretrained(
    "sd-community/sdxl-flash",
).to("cuda", torch.float16)

# Replace the default scheduler with an Euler scheduler built from the
# pipeline's own scheduler config.
flash_pipe.scheduler = EulerDiscreteScheduler.from_config(
    flash_pipe.scheduler.config,
)

# Wrapper that finds latent directions between concept pairs and renders
# images along them (see the generate / update_* callbacks).
clip_slider = CLIPSliderXL(
    flash_pipe,
    device=torch.device("cuda"),
)

def _require_pair(concepts, axis):
    """Raise a user-facing ``gr.Error`` unless exactly two concepts are selected.

    Args:
        concepts: value of a multiselect ``gr.Dropdown`` — a list of strings;
            presumably empty or ``None`` when nothing is selected.
        axis: "X" or "Y"; used only in the error message.
    """
    if not concepts or len(concepts) != 2:
        raise gr.Error(f"Please choose exactly two concepts for slider {axis}.")

def _resolve_direction(concepts, concept_1, concept_2, stored_1, stored_2, iterations):
    """Return the latent direction for ``concepts``, recomputing only when needed.

    Args:
        concepts: the two concept strings currently selected in the dropdown.
        concept_1, concept_2: the concept pair the cached direction belongs to.
        stored_1, stored_2: the two cached direction tensors from ``gr.State``
            (``None`` before the first successful run).
        iterations: forwarded to ``clip_slider.find_latent_direction``.

    Returns:
        ``(avg_diff, concept_1, concept_2)``. ``avg_diff`` is the direction to
        render with. The concept names identify the pair that ``avg_diff``
        belongs to.
    """
    # Same (order-insensitive) concept pair as last time and a cached result
    # exists: reuse it instead of re-running the expensive direction search.
    if (
        sorted(concepts) == sorted([concept_1, concept_2])
        and stored_1 is not None
        and stored_2 is not None
    ):
        # The cache is kept on the CPU (see generate), so move it back to the GPU.
        return [stored_1.cuda(), stored_2.cuda()], concept_1, concept_2

    avg_diff = clip_slider.find_latent_direction(concepts[0], concepts[1], iterations=iterations)
    return avg_diff, concepts[0], concepts[1]

@spaces.GPU
def generate(slider_x, slider_y, prompt, iterations, steps,
             x_concept_1, x_concept_2, y_concept_1, y_concept_2,
             avg_diff_x_1, avg_diff_x_2,
             avg_diff_y_1, avg_diff_y_2):
    """Compute (or reuse) the X/Y latent directions and render the base image.

    A direction is recomputed with ``clip_slider.find_latent_direction`` only
    when the dropdown's concept pair differs (order-insensitively) from the
    pair held in ``gr.State``, or when nothing has been cached yet. Otherwise
    the cached tensors are reused, so changing only the prompt or steps does
    not redo the direction search.

    Returns:
        A tuple in the order of the submit button's ``outputs``:
        - label/interactivity updates for the X and Y sliders;
        - the four concept names;
        - the four direction tensors (moved to the CPU for caching);
        - the image rendered at scale 0 on both axes.

    Raises:
        gr.Error: if either dropdown does not hold exactly two concepts.
    """
    _require_pair(slider_x, "X")
    _require_pair(slider_y, "Y")

    avg_diff, x_concept_1, x_concept_2 = _resolve_direction(
        slider_x, x_concept_1, x_concept_2, avg_diff_x_1, avg_diff_x_2, iterations)
    avg_diff_2nd, y_concept_1, y_concept_2 = _resolve_direction(
        slider_y, y_concept_1, y_concept_2, avg_diff_y_1, avg_diff_y_2, iterations)

    image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=steps,
                                 avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)

    comma_concepts_x = ', '.join(slider_x)
    comma_concepts_y = ', '.join(slider_y)

    # Cache the directions on the CPU. Presumably this is because the GPU is
    # only attached for the duration of @spaces.GPU calls — confirm.
    avg_diff_x_1 = avg_diff[0].cpu()
    avg_diff_x_2 = avg_diff[1].cpu()
    avg_diff_y_1 = avg_diff_2nd[0].cpu()
    avg_diff_y_2 = avg_diff_2nd[1].cpu()

    return (
        gr.update(label=comma_concepts_x, interactive=True),
        gr.update(label=comma_concepts_y, interactive=True),
        x_concept_1, x_concept_2, y_concept_1, y_concept_2,
        avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2,
        image,
    )

@spaces.GPU
def update_x(x, y, prompt, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
    """Re-render the image after the X slider moves.

    The cached direction tensors come from ``gr.State``, where ``generate``
    stores them on the CPU. They are moved back onto the GPU, and ``prompt``
    is rendered at scale ``x`` along the first direction and ``y`` along the
    second.
    """
    x_direction = [half.cuda() for half in (avg_diff_x_1, avg_diff_x_2)]
    y_direction = [half.cuda() for half in (avg_diff_y_1, avg_diff_y_2)]
    return clip_slider.generate(
        prompt,
        scale=x,
        scale_2nd=y,
        num_inference_steps=steps,
        avg_diff=x_direction,
        avg_diff_2nd=y_direction,
    )

@spaces.GPU
def update_y(x, y, prompt, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2):
    """Re-render the image after the Y slider moves.

    This mirrors ``update_x``. The CPU-cached directions (as stored by
    ``generate``) are brought back to the GPU, and the image is regenerated
    at the current ``(x, y)`` slider scales.
    """
    first_dir = list(map(lambda t: t.cuda(), (avg_diff_x_1, avg_diff_x_2)))
    second_dir = list(map(lambda t: t.cuda(), (avg_diff_y_1, avg_diff_y_2)))
    rendered = clip_slider.generate(
        prompt,
        scale=x,
        scale_2nd=y,
        num_inference_steps=steps,
        avg_diff=first_dir,
        avg_diff_2nd=second_dir,
    )
    return rendered
  
# Layout for the 2-D slider pad:
# - #group is a fixed 420x420 relatively-positioned box.
# - #x is pinned along its bottom edge.
# - #y is rotated -90deg about its bottom-left corner, so it runs vertically.
# - #image_out is absolutely positioned inside the box, to the right of the Y slider.
css = '''
#group {
    position: relative;
    width: 420px;
    height: 420px;
    margin-bottom: 20px;
    background-color: white
}
#x {
    position: absolute;
    bottom: 0;
    left: 25px;
    width: 400px;
}
#y {
    position: absolute;
    bottom: 20px;
    left: 67px;
    width: 400px;
    transform: rotate(-90deg);
    transform-origin: left bottom;
}
#image_out{position:absolute; width: 80%; right: 10px; top: 40px}
'''
with gr.Blocks(css=css) as demo:

    # Concept pair each cached direction was computed for. generate() compares
    # these with the dropdowns to decide whether to recompute.
    x_concept_1 = gr.State("")
    x_concept_2 = gr.State("")
    y_concept_1 = gr.State("")
    y_concept_2 = gr.State("")

    # Cached direction tensors returned by generate() and consumed by the
    # slider callbacks. They stay None until the first successful submit.
    avg_diff_x_1 = gr.State()
    avg_diff_x_2 = gr.State()
    avg_diff_y_1 = gr.State()
    avg_diff_y_2 = gr.State()

    with gr.Row():
        with gr.Column():
            slider_x = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
            slider_y = gr.Dropdown(label="Slider Y concept range", allow_custom_value=True, multiselect=True, max_choices=2)
            prompt = gr.Textbox(label="Prompt")
            submit = gr.Button("Submit")
        with gr.Group(elem_id="group"):
            # Non-interactive until generate() returns gr.update(interactive=True).
            x = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="x", interactive=False)
            y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
            output_image = gr.Image(elem_id="image_out")
    with gr.Accordion(label="advanced options"):
        # NOTE(review): minimum=0 lets the direction search run with zero
        # iterations. Confirm find_latent_direction handles that.
        iterations = gr.Slider(label = "num iterations", minimum=0, value=100, maximum=300)
        steps = gr.Slider(label = "num inference steps", minimum=1, value=8, maximum=30)

    submit.click(fn=generate,
                 inputs=[slider_x, slider_y, prompt, iterations, steps, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
                 outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])

    # update_x / update_y take (x, y, prompt, steps, avg_diff_*...).
    # `steps` must be listed so each component lines up with its parameter.
    slider_inputs = [x, y, prompt, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2]
    x.change(fn=update_x, inputs=slider_inputs, outputs=[output_image])
    y.change(fn=update_y, inputs=slider_inputs, outputs=[output_image])

if __name__ == "__main__":
    demo.launch()