linoyts HF staff commited on
Commit
b98069b
1 Parent(s): 545388a

revert taking out avg_diff from attributes part 2

Browse files
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -7,34 +7,34 @@ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
7
 
8
  flash_pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash").to("cuda", torch.float16)
9
  flash_pipe.scheduler = EulerDiscreteScheduler.from_config(flash_pipe.scheduler.config)
10
- clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"), iterations=100)
11
 
12
  @spaces.GPU
13
- def generate(slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y):
14
 
15
  # check if avg diff for directions need to be re-calculated
16
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
17
- avg_diff_x = clip_slider.find_latent_direction(slider_x[0], slider_x[1])
18
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
 
 
19
  if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
20
- avg_diff_y = clip_slider.find_latent_direction(slider_y[0], slider_y[1])
21
  y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
22
 
23
  comma_concepts_x = ', '.join(slider_x)
24
  comma_concepts_y = ', '.join(slider_y)
25
 
26
- image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8, avg_diff=avg_diff_x, avg_diff_2nd=avg_diff_y)
27
 
28
- return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, image
29
 
30
- @spaces.GPU
31
- def update_x(x,y,prompt, avg_diff_x, avg_diff_y):
32
- image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8, avg_diff=avg_diff_x, avg_diff_2nd=avg_diff_y)
33
  return image
34
 
35
- @spaces.GPU
36
- def update_y(x,y,prompt, avg_diff_x, avg_diff_y):
37
- image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8, avg_diff=avg_diff_x, avg_diff_2nd=avg_diff_y)
38
  return image
39
 
40
  css = '''
@@ -67,9 +67,6 @@ with gr.Blocks(css=css) as demo:
67
  x_concept_2 = gr.State("")
68
  y_concept_1 = gr.State("")
69
  y_concept_2 = gr.State("")
70
-
71
- avg_diff_x = gr.State(None)
72
- avg_diff_y = gr.State(None)
73
 
74
  with gr.Row():
75
  with gr.Column():
@@ -83,10 +80,10 @@ with gr.Blocks(css=css) as demo:
83
  output_image = gr.Image(elem_id="image_out")
84
 
85
  submit.click(fn=generate,
86
- inputs=[slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y],
87
- outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])
88
- x.change(fn=update_x, inputs=[x,y, prompt, avg_diff_x, avg_diff_y], outputs=[output_image])
89
- y.change(fn=update_y, inputs=[x,y, prompt, avg_diff_x, avg_diff_y], outputs=[output_image])
90
 
91
  if __name__ == "__main__":
92
  demo.launch()
 
7
 
8
  flash_pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash").to("cuda", torch.float16)
9
  flash_pipe.scheduler = EulerDiscreteScheduler.from_config(flash_pipe.scheduler.config)
10
+ clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"), iterations=150)
11
 
12
  @spaces.GPU
13
+ def generate(slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2):
14
 
15
  # check if avg diff for directions need to be re-calculated
16
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
17
+ clip_slider.avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1])
18
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
19
+ print("clip_slider.avg_diff[0]", clip_slider.avg_diff[0])
20
+ print("clip_slider.avg_diff[1]", clip_slider.avg_diff[1])
21
  if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
22
+ clip_slider.avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[1])
23
  y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
24
 
25
  comma_concepts_x = ', '.join(slider_x)
26
  comma_concepts_y = ', '.join(slider_y)
27
 
28
+ image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8)
29
 
30
+ return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, image
31
 
32
+ def update_x(x,y,prompt):
33
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
 
34
  return image
35
 
36
+ def update_y(x,y,prompt):
37
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
 
38
  return image
39
 
40
  css = '''
 
67
  x_concept_2 = gr.State("")
68
  y_concept_1 = gr.State("")
69
  y_concept_2 = gr.State("")
 
 
 
70
 
71
  with gr.Row():
72
  with gr.Column():
 
80
  output_image = gr.Image(elem_id="image_out")
81
 
82
  submit.click(fn=generate,
83
+ inputs=[slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2],
84
+ outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, output_image])
85
+ x.change(fn=update_x, inputs=[x,y, prompt], outputs=[output_image])
86
+ y.change(fn=update_y, inputs=[x,y, prompt], outputs=[output_image])
87
 
88
  if __name__ == "__main__":
89
  demo.launch()