yisol committed on
Commit
595105e
·
1 Parent(s): 3af7a49

update demo code

Browse files
Files changed (2) hide show
  1. app.py +14 -4
  2. src/tryon_pipeline.py +22 -24
app.py CHANGED
@@ -23,7 +23,7 @@ import apply_net
23
  from preprocess.humanparsing.run_parsing import Parsing
24
  from preprocess.openpose.run_openpose import OpenPose
25
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
26
-
27
 
28
 
29
  def pil_to_binary_mask(pil_image, threshold=0):
@@ -141,6 +141,8 @@ def start_tryon(dict,garm_img,garment_des,is_checked,denoise_steps,seed):
141
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
142
  mask = transforms.ToTensor()(mask)
143
  mask = mask.unsqueeze(0)
 
 
144
 
145
 
146
  human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
@@ -191,7 +193,9 @@ def start_tryon(dict,garm_img,garment_des,is_checked,denoise_steps,seed):
191
  do_classifier_free_guidance=False,
192
  negative_prompt=negative_prompt,
193
  )
194
-
 
 
195
  pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
196
  garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
197
  generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
@@ -213,7 +217,7 @@ def start_tryon(dict,garm_img,garment_des,is_checked,denoise_steps,seed):
213
  ip_adapter_image = garm_img.resize((768,1024)),
214
  guidance_scale=2.0,
215
  )[0]
216
- return images[0]
217
 
218
  garm_list = os.listdir(os.path.join(example_path,"cloth"))
219
  garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
@@ -253,10 +257,16 @@ with image_blocks as demo:
253
  inputs=garm_img,
254
  examples_per_page=8,
255
  examples=garm_list_path)
 
 
 
256
  with gr.Column():
257
  # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
258
  image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
259
 
 
 
 
260
  with gr.Column():
261
  try_button = gr.Button(value="Try-on")
262
  with gr.Accordion(label="Advanced Settings", open=False):
@@ -265,7 +275,7 @@ with image_blocks as demo:
265
  seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
266
 
267
 
268
- try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked, denoise_steps, seed], outputs=[image_out], api_name='tryon')
269
 
270
 
271
 
 
23
  from preprocess.humanparsing.run_parsing import Parsing
24
  from preprocess.openpose.run_openpose import OpenPose
25
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
26
+ from torchvision.transforms.functional import to_pil_image
27
 
28
 
29
  def pil_to_binary_mask(pil_image, threshold=0):
 
141
  mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
142
  mask = transforms.ToTensor()(mask)
143
  mask = mask.unsqueeze(0)
144
+ mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
145
+ mask_gray = to_pil_image((mask_gray+1.0)/2.0)
146
 
147
 
148
  human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
 
193
  do_classifier_free_guidance=False,
194
  negative_prompt=negative_prompt,
195
  )
196
+
197
+
198
+
199
  pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
200
  garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
201
  generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
 
217
  ip_adapter_image = garm_img.resize((768,1024)),
218
  guidance_scale=2.0,
219
  )[0]
220
+ return images[0], mask_gray
221
 
222
  garm_list = os.listdir(os.path.join(example_path,"cloth"))
223
  garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
 
257
  inputs=garm_img,
258
  examples_per_page=8,
259
  examples=garm_list_path)
260
+ with gr.Column():
261
+ # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
262
+ masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
263
  with gr.Column():
264
  # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
265
  image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
266
 
267
+
268
+
269
+
270
  with gr.Column():
271
  try_button = gr.Button(value="Try-on")
272
  with gr.Accordion(label="Advanced Settings", open=False):
 
275
  seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
276
 
277
 
278
+ try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
279
 
280
 
281
 
src/tryon_pipeline.py CHANGED
@@ -480,36 +480,30 @@ class StableDiffusionXLInpaintPipeline(
480
 
481
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
482
  def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
483
- if not isinstance(ip_adapter_image, list):
484
- ip_adapter_image = [ip_adapter_image]
485
 
486
  # if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
487
  # raise ValueError(
488
  # f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
489
  # )
 
 
 
 
 
 
 
 
 
 
 
 
490
 
491
- image_embeds = []
492
- # print(ip_adapter_image.shape)
493
- for single_ip_adapter_image in ip_adapter_image:
494
- # print(single_ip_adapter_image.shape)
495
- # ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
496
- output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
497
- # print(output_hidden_state)
498
- single_image_embeds, single_negative_image_embeds = self.encode_image(
499
- single_ip_adapter_image, device, 1, output_hidden_state
500
- )
501
- # print(single_image_embeds.shape)
502
- # single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
503
- # single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
504
- # print(single_image_embeds.shape)
505
- if self.do_classifier_free_guidance:
506
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
507
- single_image_embeds = single_image_embeds.to(device)
508
-
509
- image_embeds.append(single_image_embeds)
510
 
511
  return image_embeds
512
 
 
513
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
514
  def encode_prompt(
515
  self,
@@ -1724,8 +1718,10 @@ class StableDiffusionXLInpaintPipeline(
1724
  image_embeds = self.prepare_ip_adapter_image_embeds(
1725
  ip_adapter_image, device, batch_size * num_images_per_prompt
1726
  )
1727
- # print("a")
1728
- # print(image_embeds[0].shape)
 
 
1729
 
1730
  # 11. Denoising loop
1731
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1759,6 +1755,8 @@ class StableDiffusionXLInpaintPipeline(
1759
  guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1760
  ).to(device=device, dtype=latents.dtype)
1761
 
 
 
1762
  self._num_timesteps = len(timesteps)
1763
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1764
  for i, t in enumerate(timesteps):
@@ -1781,7 +1779,7 @@ class StableDiffusionXLInpaintPipeline(
1781
  # predict the noise residual
1782
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1783
  if ip_adapter_image is not None:
1784
- added_cond_kwargs["image_embeds"] = image_embeds[0]
1785
  # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
1786
  down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
1787
  # print(type(reference_features))
 
480
 
481
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
482
  def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
483
+ # if not isinstance(ip_adapter_image, list):
484
+ # ip_adapter_image = [ip_adapter_image]
485
 
486
  # if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
487
  # raise ValueError(
488
  # f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
489
  # )
490
+ output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
491
+ # print(output_hidden_state)
492
+ image_embeds, negative_image_embeds = self.encode_image(
493
+ ip_adapter_image, device, 1, output_hidden_state
494
+ )
495
+ # print(single_image_embeds.shape)
496
+ # single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
497
+ # single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
498
+ # print(single_image_embeds.shape)
499
+ if self.do_classifier_free_guidance:
500
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
501
+ image_embeds = image_embeds.to(device)
502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
  return image_embeds
505
 
506
+
507
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
508
  def encode_prompt(
509
  self,
 
1718
  image_embeds = self.prepare_ip_adapter_image_embeds(
1719
  ip_adapter_image, device, batch_size * num_images_per_prompt
1720
  )
1721
+
1722
+ # project the image embeds once, outside the denoising loop
1723
+ image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype)
1724
+
1725
 
1726
  # 11. Denoising loop
1727
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
1755
  guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1756
  ).to(device=device, dtype=latents.dtype)
1757
 
1758
+
1759
+
1760
  self._num_timesteps = len(timesteps)
1761
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1762
  for i, t in enumerate(timesteps):
 
1779
  # predict the noise residual
1780
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1781
  if ip_adapter_image is not None:
1782
+ added_cond_kwargs["image_embeds"] = image_embeds
1783
  # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
1784
  down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
1785
  # print(type(reference_features))