animrods commited on
Commit
527b2cf
1 Parent(s): 420808c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -7
app.py CHANGED
@@ -8,6 +8,8 @@ import spaces
8
  from PIL import Image
9
  hf_token = os.environ.get("HF_TOKEN")
10
  from diffusers import AutoPipelineForText2Image
 
 
11
 
12
 
13
  device = "cuda" #if torch.cuda.is_available() else "cpu"
@@ -19,15 +21,19 @@ pipe.to(device)
19
  MAX_SEED = np.iinfo(np.int32).max
20
 
21
  @spaces.GPU(enable_queue=True)
22
- def predict(prompt, ip_adapter_images, ip_adapter_scale=0.5, negative_prompt="", seed=100, randomize_seed=False, center_crop=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
23
  if randomize_seed:
24
  seed = random.randint(0, MAX_SEED)
25
 
26
- ip_adapter_images = [Image.open(image) for image in ip_adapter_images]
 
 
 
 
27
 
28
- # Optionally resize images if center crop is not selected
29
- if not center_crop:
30
- ip_adapter_images = [image.resize((224, 224)) for image in ip_adapter_images]
31
 
32
  generator = torch.Generator(device="cuda").manual_seed(seed)
33
  pipe.set_ip_adapter_scale([ip_adapter_scale])
@@ -72,7 +78,16 @@ with gr.Blocks(css=css) as demo:
72
  with gr.Row():
73
  with gr.Column():
74
  # ip_adapter_images = gr.Gallery(label="Input Images", elem_id="image-gallery").style(grid=[2], preview=True)
75
- ip_adapter_images = gr.Gallery(label="Input Images", elem_id="image-gallery", show_label=True)#.style(grid=[2])
 
 
 
 
 
 
 
 
 
76
 
77
  ip_adapter_scale = gr.Slider(
78
  label="Image Input Scale",
@@ -156,6 +171,8 @@ with gr.Blocks(css=css) as demo:
156
  step=1,
157
  value=25,
158
  )
 
 
159
 
160
 
161
  # gr.Examples(
@@ -169,7 +186,7 @@ with gr.Blocks(css=css) as demo:
169
  gr.on(
170
  triggers=[run_button.click, prompt.submit],
171
  fn=predict,
172
- inputs=[prompt, ip_adapter_images, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height, guidance_scale, num_inference_steps],
173
  outputs=[result, seed]
174
  )
175
 
 
8
  from PIL import Image
9
  hf_token = os.environ.get("HF_TOKEN")
10
  from diffusers import AutoPipelineForText2Image
11
+ from diffusers.utils import load_image
12
+
13
 
14
 
15
  device = "cuda" #if torch.cuda.is_available() else "cpu"
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
 
23
  @spaces.GPU(enable_queue=True)
24
+ def predict(prompt, files, ip_adapter_scale=0.5, negative_prompt="", seed=100, randomize_seed=False, center_crop=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=50, progress=gr.Progress(track_tqdm=True)):
25
  if randomize_seed:
26
  seed = random.randint(0, MAX_SEED)
27
 
28
+ ip_adapter_images = []
29
+ for img in upload_images:
30
+ ip_adapter_images.append(load_image(img))
31
+
32
+ # ip_adapter_images = [Image.open(image) for image in ip_adapter_images]
33
 
34
+ # # Optionally resize images if center crop is not selected
35
+ # if not center_crop:
36
+ # ip_adapter_images = [image.resize((224, 224)) for image in ip_adapter_images]
37
 
38
  generator = torch.Generator(device="cuda").manual_seed(seed)
39
  pipe.set_ip_adapter_scale([ip_adapter_scale])
 
78
  with gr.Row():
79
  with gr.Column():
80
  # ip_adapter_images = gr.Gallery(label="Input Images", elem_id="image-gallery").style(grid=[2], preview=True)
81
+ # ip_adapter_images = gr.Gallery(label="Input Images", elem_id="image-gallery", show_label=True)#.style(grid=[2])
82
+ ip_adapter_images = gr.Gallery(columns=4, interactive=True, label="Input Images")
83
+
84
+ files = gr.File(
85
+ label="Input Image/s",
86
+ file_types=["image"],
87
+ file_count="multiple"
88
+ )
89
+ uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
90
+
91
 
92
  ip_adapter_scale = gr.Slider(
93
  label="Image Input Scale",
 
171
  step=1,
172
  value=25,
173
  )
174
+ files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
175
+ remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
176
 
177
 
178
  # gr.Examples(
 
186
  gr.on(
187
  triggers=[run_button.click, prompt.submit],
188
  fn=predict,
189
+ inputs=[prompt, files, ip_adapter_scale, negative_prompt, seed, randomize_seed, center_crop, width, height, guidance_scale, num_inference_steps],
190
  outputs=[result, seed]
191
  )
192