chaoxu commited on
Commit
83bd11e
1 Parent(s): 5634871

Predict normal for multiviews by controlnet

Browse files
Files changed (1) hide show
  1. gradio_app.py +44 -7
gradio_app.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
 
2
  import torch
3
  import fire
4
  import gradio as gr
5
  from PIL import Image
6
  from functools import partial
7
- from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
8
  from share_btn import community_icon_html, loading_icon_html, share_js
9
 
10
  import cv2
@@ -122,7 +123,7 @@ def save_image(image, original_image):
122
  os.remove(out_path)
123
  os.remove(in_path)
124
 
125
- def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider, seed, output_processing=False, original_image=None):
126
  seed = int(seed)
127
  torch.manual_seed(seed)
128
  image = pipeline(input_image,
@@ -131,6 +132,15 @@ def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider,
131
  generator=torch.Generator(pipeline.device).manual_seed(seed)).images[0]
132
  side_len = image.width//2
133
  subimages = [image.crop((x, y, x + side_len, y+side_len)) for y in range(0, image.height, side_len) for x in range(0, image.width, side_len)]
 
 
 
 
 
 
 
 
 
134
  if "Background Removal" in output_processing:
135
  out_images = []
136
  merged_image = Image.new('RGB', (640, 960))
@@ -142,9 +152,17 @@ def gen_multiview(pipeline, predictor, input_image, scale_slider, steps_slider,
142
  y = (i % 3) * 320
143
  merged_image.paste(sub_image, (x, y))
144
  save_image(merged_image, original_image)
145
- return out_images + [merged_image]
 
 
 
 
 
 
 
 
146
  save_image(image, original_image)
147
- return subimages + [image]
148
 
149
 
150
  def run_demo():
@@ -159,6 +177,14 @@ def run_demo():
159
  )
160
  pipeline.to(f'cuda:{_GPU_ID}')
161
 
 
 
 
 
 
 
 
 
162
  predictor = sam_init()
163
 
164
  custom_theme = gr.themes.Soft(primary_hue="blue").set(
@@ -187,6 +213,8 @@ def run_demo():
187
  label='Examples (click one of the images below to start)',
188
  examples_per_page=10
189
  )
 
 
190
  with gr.Accordion('Advanced options', open=False):
191
  with gr.Row():
192
  with gr.Column():
@@ -213,6 +241,14 @@ def run_demo():
213
  view_4 = gr.Image(interactive=False, height=240, show_label=False)
214
  view_5 = gr.Image(interactive=False, height=240, show_label=False)
215
  view_6 = gr.Image(interactive=False, height=240, show_label=False)
 
 
 
 
 
 
 
 
216
  full_view = gr.Image(visible=False, interactive=False, elem_id="six_view")
217
  with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
218
  community_icon = gr.HTML(community_icon_html)
@@ -227,9 +263,10 @@ def run_demo():
227
  ).success(fn=partial(preprocess, predictor),
228
  inputs=[input_image, input_processing],
229
  outputs=[processed_image_highres, processed_image], queue=True
230
- ).success(fn=partial(gen_multiview, pipeline, predictor),
231
- inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing, input_image],
232
- outputs=[view_1, view_2, view_3, view_4, view_5, view_6, full_view], queue=True
 
233
  ).success(show_share_btn, outputs=share_group, queue=False)
234
 
235
  share_button.click(None, [], [], _js=share_js)
 
1
  import os
2
+ import copy
3
  import torch
4
  import fire
5
  import gradio as gr
6
  from PIL import Image
7
  from functools import partial
8
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, ControlNetModel
9
  from share_btn import community_icon_html, loading_icon_html, share_js
10
 
11
  import cv2
 
123
  os.remove(out_path)
124
  os.remove(in_path)
125
 
126
+ def gen_multiview(pipeline, pipeline_normal, predictor, input_image, scale_slider, steps_slider, seed, output_processing=False, original_image=None, out_normal=True):
127
  seed = int(seed)
128
  torch.manual_seed(seed)
129
  image = pipeline(input_image,
 
132
  generator=torch.Generator(pipeline.device).manual_seed(seed)).images[0]
133
  side_len = image.width//2
134
  subimages = [image.crop((x, y, x + side_len, y+side_len)) for y in range(0, image.height, side_len) for x in range(0, image.width, side_len)]
135
+
136
+ # normal images
137
+ out_images_normal = [gr.Image(None) for _ in range(6)]
138
+ if out_normal:
139
+ image_normal = pipeline_normal(input_image, depth_image=image,
140
+ prompt='', guidance_scale=1, num_inference_steps=50, width=640, height=960
141
+ ).images[0]
142
+ subimages_normal = [image_normal.crop((x, y, x + side_len, y+side_len)) for y in range(0, image_normal.height, side_len) for x in range(0, image_normal.width, side_len)]
143
+
144
  if "Background Removal" in output_processing:
145
  out_images = []
146
  merged_image = Image.new('RGB', (640, 960))
 
152
  y = (i % 3) * 320
153
  merged_image.paste(sub_image, (x, y))
154
  save_image(merged_image, original_image)
155
+
156
+ if out_normal:
157
+ out_images_normal = []
158
+ # merged_image_normal = Image.new('RGB', (640, 960))
159
+ for i, sub_image in enumerate(subimages_normal):
160
+ sub_image, _ = preprocess(predictor, sub_image.convert('RGB'), segment=True, rescale=False)
161
+ out_images_normal.append(sub_image)
162
+
163
+ return out_images + [merged_image] + out_images_normal
164
  save_image(image, original_image)
165
+ return subimages + [image] + out_images_normal
166
 
167
 
168
  def run_demo():
 
177
  )
178
  pipeline.to(f'cuda:{_GPU_ID}')
179
 
180
+ normal_pipeline = copy.copy(pipeline)
181
+ controlnet = ControlNetModel.from_pretrained(
182
+ "sudo-ai/controlnet-zp12-normal-gen-v1",
183
+ torch_dtype=torch.float16, use_auth_token=os.environ["HF_TOKEN"]
184
+ )
185
+ normal_pipeline.add_controlnet(controlnet, conditioning_scale=1.0)
186
+ normal_pipeline.to(f'cuda:{_GPU_ID}')
187
+
188
  predictor = sam_init()
189
 
190
  custom_theme = gr.themes.Soft(primary_hue="blue").set(
 
213
  label='Examples (click one of the images below to start)',
214
  examples_per_page=10
215
  )
216
+ with gr.Row():
217
+ out_normal = gr.Checkbox(value=True, label='Predict normal images for generated multiviews', elem_id="out_normal")
218
  with gr.Accordion('Advanced options', open=False):
219
  with gr.Row():
220
  with gr.Column():
 
241
  view_4 = gr.Image(interactive=False, height=240, show_label=False)
242
  view_5 = gr.Image(interactive=False, height=240, show_label=False)
243
  view_6 = gr.Image(interactive=False, height=240, show_label=False)
244
+ with gr.Row():
245
+ norm_1 = gr.Image(interactive=False, height=240, show_label=False)
246
+ norm_2 = gr.Image(interactive=False, height=240, show_label=False)
247
+ norm_3 = gr.Image(interactive=False, height=240, show_label=False)
248
+ with gr.Row():
249
+ norm_4 = gr.Image(interactive=False, height=240, show_label=False)
250
+ norm_5 = gr.Image(interactive=False, height=240, show_label=False)
251
+ norm_6 = gr.Image(interactive=False, height=240, show_label=False)
252
  full_view = gr.Image(visible=False, interactive=False, elem_id="six_view")
253
  with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
254
  community_icon = gr.HTML(community_icon_html)
 
263
  ).success(fn=partial(preprocess, predictor),
264
  inputs=[input_image, input_processing],
265
  outputs=[processed_image_highres, processed_image], queue=True
266
+ ).success(fn=partial(gen_multiview, pipeline, normal_pipeline, predictor),
267
+ inputs=[processed_image_highres, scale_slider, steps_slider, seed, output_processing, input_image, out_normal],
268
+ outputs=[view_1, view_2, view_3, view_4, view_5, view_6, full_view,
269
+ norm_1, norm_2, norm_3, norm_4, norm_5, norm_6], queue=True
270
  ).success(show_share_btn, outputs=share_group, queue=False)
271
 
272
  share_button.click(None, [], [], _js=share_js)