aiqcamp committed
Commit 46c3f57 · verified · 1 Parent(s): 798e1fa

Update app.py

Files changed (1)
  1. app.py +390 -155
app.py CHANGED
@@ -1,5 +1,5 @@
- import spaces
  import gradio as gr
  from PIL import Image
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
@@ -16,7 +16,6 @@ from typing import List
  import torch
  import os
  from transformers import AutoTokenizer
-
  import numpy as np
  from utils_mask import get_mask_location
  from torchvision import transforms
@@ -26,6 +25,7 @@ from preprocess.openpose.run_openpose import OpenPose
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
  from torchvision.transforms.functional import to_pil_image
 
 
  def pil_to_binary_mask(pil_image, threshold=0):
      np_image = np.array(pil_image)
      grayscale_image = Image.fromarray(np_image).convert("L")
@@ -39,53 +39,324 @@ def pil_to_binary_mask(pil_image, threshold=0):
  output_mask = Image.fromarray(mask)
  return output_mask
 
- import numpy as np
- from PIL import Image
 
- def get_mask_location(mode, category, parsing, keypoints):
-     parsing = np.array(parsing)
-     mask = np.zeros_like(parsing)
 
-     print(f"Selected category: {category}")
-     print(f"Parsing shape: {parsing.shape}")
-     print(f"Unique values in parsing: {np.unique(parsing)}")
 
-     if category == "상의":
-         # Mask only the parts corresponding to tops (upper body, arms)
-         upper_body = [5, 6, 7]
-         mask[np.isin(parsing, upper_body)] = 255
-         print(f"Masking upper body parts: {upper_body}")
-     elif category == "하의":
-         # Mask only the parts corresponding to bottoms (lower body)
-         lower_body = [9, 12, 13, 14, 15, 16, 17, 18, 19]
-         mask[np.isin(parsing, lower_body)] = 255
-         print(f"Masking lower body parts: {lower_body}")
-     elif category == "드레스":
-         # Mask the parts corresponding to a dress (upper and lower body)
-         full_body = [5, 6, 7, 9, 12, 13, 14, 15, 16, 17, 18, 19]
-         mask[np.isin(parsing, full_body)] = 255
-         print(f"Masking full body parts: {full_body}")
      else:
-         raise ValueError(f"Unknown category: {category}")
-
-     print(f"Mask shape: {mask.shape}, Unique values in mask: {np.unique(mask)}")
-     print(f"Number of masked pixels: {np.sum(mask == 255)}")
 
-     # Code added for mask visualization
-     import matplotlib.pyplot as plt
-     plt.figure(figsize=(10, 10))
-     plt.imshow(mask, cmap='gray')
-     plt.title(f"Mask for {category}")
-     plt.savefig(f"mask_{category}.png")
-     plt.close()
 
-     mask_gray = Image.fromarray(mask.astype(np.uint8))
-     return mask_gray, mask_gray
 
- base_path = 'yisol/IDM-VTON'
  example_path = os.path.join(os.path.dirname(__file__), 'example')
 
  unet = UNet2DConditionModel.from_pretrained(
@@ -128,6 +399,7 @@ vae = AutoencoderKL.from_pretrained(base_path,
      torch_dtype=torch.float16,
  )
 
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
      base_path,
      subfolder="unet_encoder",
@@ -165,17 +437,17 @@ pipe = TryonPipeline.from_pretrained(
  )
  pipe.unet_encoder = UNet_Encoder
 
- @spaces.GPU
- def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, category):
      device = "cuda"
-
      openpose_model.preprocessor.body_estimation.model.to(device)
      pipe.to(device)
      pipe.unet_encoder.to(device)
 
-     garm_img = garm_img.convert("RGB").resize((768,1024))
-     human_img_orig = dict["background"].convert("RGB")
-
      if is_checked_crop:
          width, height = human_img_orig.size
          target_width = int(min(width, height * (3 / 4)))
@@ -191,72 +463,36 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
      human_img = human_img_orig.resize((768,1024))
 
 
-     status_message = ""
      if is_checked:
-         try:
-             print(f"Processing category: {category}")
-             keypoints = openpose_model(human_img.resize((384,512)))
-             model_parse, _ = parsing_model(human_img.resize((384,512)))
-
-             # Check the parsing model output
-             print(f"Parsing model output shape: {model_parse.shape}")
-             print(f"Unique values in parsing model output: {np.unique(model_parse)}")
-
-             mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints)
-
-             # Check and visualize the mask
-             mask_array = np.array(mask)
-             print(f"Mask shape after get_mask_location: {mask_array.shape}")
-             print(f"Unique values in mask after get_mask_location: {np.unique(mask_array)}")
-             print(f"Number of masked pixels after get_mask_location: {np.sum(mask_array == 255)}")
-
-             plt.figure(figsize=(10, 10))
-             plt.imshow(mask_array, cmap='gray')
-             plt.title(f"Mask after get_mask_location for {category}")
-             plt.savefig(f"mask_after_get_mask_location_{category}.png")
-             plt.close()
-
-             mask = mask.resize((768,1024))
-             print(f"Mask created for category {category}")
-
-             # Check the final mask
-             mask_array_final = np.array(mask)
-             print(f"Final mask shape: {mask_array_final.shape}")
-             print(f"Unique values in final mask: {np.unique(mask_array_final)}")
-             print(f"Number of masked pixels in final mask: {np.sum(mask_array_final == 255)}")
-
-             plt.figure(figsize=(10, 10))
-             plt.imshow(mask_array_final, cmap='gray')
-             plt.title(f"Final Mask for {category}")
-             plt.savefig(f"final_mask_{category}.png")
-             plt.close()
-
-         except Exception as e:
-             status_message = f"자동 마스크 생성 중 오류가 발생했습니다: {str(e)}. 기본 마스크를 사용합니다."
-             print(f"Error in mask creation: {str(e)}")
-             mask = Image.new('L', (768, 1024), 255)
      else:
-         if dict['layers'] and dict['layers'][0]:
-             mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
-         else:
-             mask = Image.new('L', (768, 1024), 255)
-
      mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
      mask_gray = to_pil_image((mask_gray+1.0)/2.0)
 
      human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
      human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
 
      args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
-     pose_img = args.func(args,human_img_arg)
-     pose_img = pose_img[:,:,::-1]
      pose_img = Image.fromarray(pose_img).resize((768,1024))
-
      with torch.no_grad():
          with torch.cuda.amp.autocast():
              with torch.no_grad():
-                 prompt = "((best quality, masterpiece, ultra-detailed, high quality photography, photo realistic)), the model is wearing " + garment_des
-                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, normal quality, low quality, blurry, jpeg artifacts, sketch"
                  with torch.inference_mode():
                      (
                          prompt_embeds,
@@ -269,9 +505,9 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
                          do_classifier_free_guidance=True,
                          negative_prompt=negative_prompt,
                      )
-
-                 prompt = "((best quality, masterpiece, ultra-detailed, high quality photography, photo realistic)), a photo of " + garment_des
-                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, normal quality, low quality, blurry, jpeg artifacts, sketch"
                  if not isinstance(prompt, List):
                      prompt = [prompt] * 1
                  if not isinstance(negative_prompt, List):
@@ -289,10 +525,12 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
                          negative_prompt=negative_prompt,
                      )
 
                  pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
                  garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
                  generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-                 result = pipe(
                      prompt_embeds=prompt_embeds.to(device,torch.float16),
                      negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
                      pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
@@ -304,34 +542,22 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
                      text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
                      cloth = garm_tensor.to(device,torch.float16),
                      mask_image=mask,
-                     image=human_img,
                      height=1024,
                      width=768,
                      ip_adapter_image = garm_img.resize((768,1024)),
                      guidance_scale=2.0,
-                 )
-
-                 # Check the result format and extract the images
-                 if isinstance(result, tuple):
-                     images = result[0]
-                 elif hasattr(result, 'images'):
-                     images = result.images
-                 else:
-                     raise ValueError(f"Unexpected result type: {type(result)}")
-
-                 print(f"Result type: {type(result)}")
-                 print(f"Result content: {result}")
-                 print(f"Mask shape: {mask.size}")
-                 print(f"Human image shape: {human_img.size}")
-                 print(f"Garment image shape: {garm_img.size}")
-                 print(f"Output image shape: {images[0].size}")
 
      if is_checked_crop:
-         out_img = images[0].resize(crop_size)
-         human_img_orig.paste(out_img, (int(left), int(top)))
-         return human_img_orig, mask_gray, status_message
      else:
-         return images[0], mask_gray, status_message
 
  garm_list = os.listdir(os.path.join(example_path,"cloth"))
  garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
@@ -347,55 +573,64 @@ for ex_human in human_list_path:
  ex_dict['composite'] = None
  human_ex_list.append(ex_dict)
 
  image_blocks = gr.Blocks(theme="Nymbo/Nymbo_Theme").queue(max_size=12)
  with image_blocks as demo:
-     with gr.Column():
-         try_button = gr.Button(value="가상 피팅 시작")
-         with gr.Accordion(label="고급 설정", open=False):
-             with gr.Row():
-                 denoise_steps = gr.Number(label="디노이징 단계", minimum=20, maximum=40, value=30, step=1)
-                 seed = gr.Number(label="시드", minimum=-1, maximum=2147483647, step=1, value=-1)
 
      with gr.Row():
          with gr.Column():
-             imgs = gr.ImageEditor(sources='upload', type="pil", label='인물 사진. 펜으로 마스크 또는 자동 마스킹 사용', interactive=True)
              with gr.Row():
-                 is_checked = gr.Checkbox(label="", info="자동 생성 마스크 사용 (5 소요)",value=True)
              with gr.Row():
-                 category = gr.Dropdown(
-                     choices=["상의", "하의", "드레스"],
-                     label="카테고리",
-                     value="상의"
                  )
-             with gr.Row():
-                 is_checked_crop = gr.Checkbox(label="예", info="자동 자르기 및 크기 조정 사용",value=False)
-
-             example = gr.Examples(
-                 inputs=imgs,
-                 examples_per_page=15,
-                 examples=human_ex_list
-             )
 
          with gr.Column():
-             garm_img = gr.Image(label="의류", sources='upload', type="pil")
              with gr.Row(elem_id="prompt-container"):
                  with gr.Row():
-                     prompt = gr.Textbox(label="의류 설명", placeholder="반소매 라운드넥 티셔츠", show_label=True, elem_id="prompt")
              example = gr.Examples(
                  inputs=garm_img,
-                 examples_per_page=16,
                  examples=garm_list_path)
-         with gr.Column():
-             masked_img = gr.Image(label="마스크 적용 이미지", elem_id="masked-img",show_share_button=False)
-         with gr.Column():
-             image_out = gr.Image(label="결과", elem_id="output-img",show_share_button=False)
 
          with gr.Column():
-             status_message = gr.Textbox(label="상태", interactive=False)
 
-     try_button.click(fn=start_tryon,
-                      inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed, category],
-                      outputs=[image_out, masked_img, status_message],
-                      api_name='tryon')
 
  image_blocks.launch(auth=("gini","pick"))
 
  import gradio as gr
+ import spaces
  from PIL import Image
  from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
  from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
 
  import torch
  import os
  from transformers import AutoTokenizer
  import numpy as np
  from utils_mask import get_mask_location
  from torchvision import transforms
 
  from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
  from torchvision.transforms.functional import to_pil_image
 
+
  def pil_to_binary_mask(pil_image, threshold=0):
      np_image = np.array(pil_image)
      grayscale_image = Image.fromarray(np_image).convert("L")
 
      output_mask = Image.fromarray(mask)
      return output_mask
 
+ base_path = 'Roopansh/Ailusion-VTON-DEMO-v1.1'
+ example_path = os.path.join(os.path.dirname(__file__), 'example')
+
+ unet = UNet2DConditionModel.from_pretrained(
+     base_path,
+     subfolder="unet",
+     torch_dtype=torch.float16,
+ )
+ unet.requires_grad_(False)
+ tokenizer_one = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer",
+     revision=None,
+     use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+     base_path,
+     subfolder="tokenizer_2",
+     revision=None,
+     use_fast=False,
+ )
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+
+ text_encoder_one = CLIPTextModel.from_pretrained(
+     base_path,
+     subfolder="text_encoder",
+     torch_dtype=torch.float16,
+ )
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="text_encoder_2",
+     torch_dtype=torch.float16,
+ )
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+     base_path,
+     subfolder="image_encoder",
+     torch_dtype=torch.float16,
+ )
+ vae = AutoencoderKL.from_pretrained(base_path,
+     subfolder="vae",
+     torch_dtype=torch.float16,
+ )
+
+ # "stabilityai/stable-diffusion-xl-base-1.0",
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
+     base_path,
+     subfolder="unet_encoder",
+     torch_dtype=torch.float16,
+ )
+
+ parsing_model = Parsing(0)
+ openpose_model = OpenPose(0)
+
+ UNet_Encoder.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ vae.requires_grad_(False)
+ unet.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ tensor_transfrom = transforms.Compose(
+     [
+         transforms.ToTensor(),
+         transforms.Normalize([0.5], [0.5]),
+     ]
+ )
+
+ pipe = TryonPipeline.from_pretrained(
+     base_path,
+     unet=unet,
+     vae=vae,
+     feature_extractor= CLIPImageProcessor(),
+     text_encoder = text_encoder_one,
+     text_encoder_2 = text_encoder_two,
+     tokenizer = tokenizer_one,
+     tokenizer_2 = tokenizer_two,
+     scheduler = noise_scheduler,
+     image_encoder=image_encoder,
+     torch_dtype=torch.float16,
+ )
+ pipe.unet_encoder = UNet_Encoder
 
+ @spaces.GPU(duration=120)
+ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
+     device = "cuda"
 
+     openpose_model.preprocessor.body_estimation.model.to(device)
+     pipe.to(device)
+     pipe.unet_encoder.to(device)
+
+     garm_img= garm_img.convert("RGB").resize((768,1024))
+     human_img_orig = dict["background"].convert("RGB")
 
+     if is_checked_crop:
+         width, height = human_img_orig.size
+         target_width = int(min(width, height * (3 / 4)))
+         target_height = int(min(height, width * (4 / 3)))
+         left = (width - target_width) / 2
+         top = (height - target_height) / 2
+         right = (width + target_width) / 2
+         bottom = (height + target_height) / 2
+         cropped_img = human_img_orig.crop((left, top, right, bottom))
+         crop_size = cropped_img.size
+         human_img = cropped_img.resize((768,1024))
      else:
+         human_img = human_img_orig.resize((768,1024))
+
+
+     if is_checked:
+         keypoints = openpose_model(human_img.resize((384,512)))
+         model_parse, _ = parsing_model(human_img.resize((384,512)))
+         mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+         mask = mask.resize((768,1024))
      else:
+         mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+         # mask = transforms.ToTensor()(mask)
+         # mask = mask.unsqueeze(0)
+     mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
+     mask_gray = to_pil_image((mask_gray+1.0)/2.0)
+
+
+     human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
+     human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
 
+     args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+     # verbosity = getattr(args, "verbosity", None)
+     pose_img = args.func(args,human_img_arg)
+     pose_img = pose_img[:,:,::-1]
+     pose_img = Image.fromarray(pose_img).resize((768,1024))
 
+     with torch.no_grad():
+         # Extract the images
+         with torch.cuda.amp.autocast():
+             with torch.no_grad():
+                 prompt = "model is wearing " + garment_des
+                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                 with torch.inference_mode():
+                     (
+                         prompt_embeds,
+                         negative_prompt_embeds,
+                         pooled_prompt_embeds,
+                         negative_pooled_prompt_embeds,
+                     ) = pipe.encode_prompt(
+                         prompt,
+                         num_images_per_prompt=1,
+                         do_classifier_free_guidance=True,
+                         negative_prompt=negative_prompt,
+                     )
+
+                 prompt = "a photo of " + garment_des
+                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+                 if not isinstance(prompt, List):
+                     prompt = [prompt] * 1
+                 if not isinstance(negative_prompt, List):
+                     negative_prompt = [negative_prompt] * 1
+                 with torch.inference_mode():
+                     (
+                         prompt_embeds_c,
+                         _,
+                         _,
+                         _,
+                     ) = pipe.encode_prompt(
+                         prompt,
+                         num_images_per_prompt=1,
+                         do_classifier_free_guidance=False,
+                         negative_prompt=negative_prompt,
+                     )
+
+                 pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
+                 garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
+                 generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                 images = pipe(
+                     prompt_embeds=prompt_embeds.to(device,torch.float16),
+                     negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
+                     pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
+                     negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
+                     num_inference_steps=denoise_steps,
+                     generator=generator,
+                     strength = 1.0,
+                     pose_img = pose_img.to(device,torch.float16),
+                     text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
+                     cloth = garm_tensor.to(device,torch.float16),
+                     mask_image=mask,
+                     image=human_img,
+                     height=1024,
+                     width=768,
+                     ip_adapter_image = garm_img.resize((768,1024)),
+                     guidance_scale=2.0,
+                 )[0]
+
+     if is_checked_crop:
+         out_img = images[0].resize(crop_size)
+         human_img_orig.paste(out_img, (int(left), int(top)))
+         # return human_img_orig, mask_gray
+         return human_img_orig
+     else:
+         # return images[0], mask_gray
+         return images[0]
+     # return images[0], mask_gray
+
+ garm_list = os.listdir(os.path.join(example_path,"cloth"))
+ garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
+
+ human_list = os.listdir(os.path.join(example_path,"human"))
+ human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
+
+ human_ex_list = []
+ for ex_human in human_list_path:
+     ex_dict= {}
+     ex_dict['background'] = ex_human
+     ex_dict['layers'] = None
+     ex_dict['composite'] = None
+     human_ex_list.append(ex_dict)
+
+ ##default human
+
+ image_blocks = gr.Blocks().queue()
+ with image_blocks as demo:
+     # gr.Markdown("## AILUSION VTON DEMO 👕👔👚")
+     # gr.Markdown("Virtual Try-on with your image and garment image.")
 
+     with gr.Row():
+         with gr.Column():
+             imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
+             with gr.Row():
+                 is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
+             with gr.Row():
+                 is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
+
+             with gr.Row(equal_height=True):
+                 example = gr.Examples(
+                     inputs=imgs,
+                     examples_per_page=5,
+                     examples=human_ex_list
+                 )
+
+         with gr.Column():
+             garm_img = gr.Image(label="Garment", sources='upload', type="pil")
+             with gr.Row(elem_id="prompt-container"):
+                 with gr.Row():
+                     prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
+             example = gr.Examples(
+                 inputs=garm_img,
+                 examples_per_page=8,
+                 examples=garm_list_path)
+         # with gr.Column():
+         #     image_out = gr.Image(label="Output", elem_id="output-img", height=400)
+         #     masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
+
+         # masked_img = ()
+
+         with gr.Column():
+             # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
+             image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
+
+         with gr.Column():
+             try_button = gr.Button(value="Try-on")
+             with gr.Accordion(label="Advanced Settings", open=False):
+                 with gr.Row():
+                     denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
+                     seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
+
+     try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out], api_name='tryon')
+
+ image_blocks.launch()
+ import gradio as gr
+ import spaces
+ from PIL import Image
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
+ from src.unet_hacked_tryon import UNet2DConditionModel
+ from transformers import (
+     CLIPImageProcessor,
+     CLIPVisionModelWithProjection,
+     CLIPTextModel,
+     CLIPTextModelWithProjection,
+ )
+ from diffusers import DDPMScheduler,AutoencoderKL
+ from typing import List
+
+ import torch
+ import os
+ from transformers import AutoTokenizer
+ import numpy as np
+ from utils_mask import get_mask_location
+ from torchvision import transforms
+ import apply_net
+ from preprocess.humanparsing.run_parsing import Parsing
+ from preprocess.openpose.run_openpose import OpenPose
+ from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
+ from torchvision.transforms.functional import to_pil_image
+
+
+ def pil_to_binary_mask(pil_image, threshold=0):
+     np_image = np.array(pil_image)
+     grayscale_image = Image.fromarray(np_image).convert("L")
+     binary_mask = np.array(grayscale_image) > threshold
+     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
+     for i in range(binary_mask.shape[0]):
+         for j in range(binary_mask.shape[1]):
+             if binary_mask[i,j] == True :
+                 mask[i,j] = 1
+     mask = (mask*255).astype(np.uint8)
+     output_mask = Image.fromarray(mask)
+     return output_mask
 
 
+ base_path = 'Roopansh/Ailusion-VTON-DEMO-v1.1'
  example_path = os.path.join(os.path.dirname(__file__), 'example')
 
  unet = UNet2DConditionModel.from_pretrained(
 
      torch_dtype=torch.float16,
  )
 
+ # "stabilityai/stable-diffusion-xl-base-1.0",
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
      base_path,
      subfolder="unet_encoder",
 
  )
  pipe.unet_encoder = UNet_Encoder
 
+ @spaces.GPU(duration=120)
+ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
      device = "cuda"
+
      openpose_model.preprocessor.body_estimation.model.to(device)
      pipe.to(device)
      pipe.unet_encoder.to(device)
 
+     garm_img= garm_img.convert("RGB").resize((768,1024))
+     human_img_orig = dict["background"].convert("RGB")
+
      if is_checked_crop:
          width, height = human_img_orig.size
          target_width = int(min(width, height * (3 / 4)))
 
          human_img = human_img_orig.resize((768,1024))
 
      if is_checked:
+         keypoints = openpose_model(human_img.resize((384,512)))
+         model_parse, _ = parsing_model(human_img.resize((384,512)))
+         mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
+         mask = mask.resize((768,1024))
      else:
+         mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+         # mask = transforms.ToTensor()(mask)
+         # mask = mask.unsqueeze(0)
      mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
      mask_gray = to_pil_image((mask_gray+1.0)/2.0)
 
+
      human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
      human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
+
 
      args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
+     # verbosity = getattr(args, "verbosity", None)
+     pose_img = args.func(args,human_img_arg)
+     pose_img = pose_img[:,:,::-1]
      pose_img = Image.fromarray(pose_img).resize((768,1024))
+
      with torch.no_grad():
+         # Extract the images
          with torch.cuda.amp.autocast():
              with torch.no_grad():
+                 prompt = "model is wearing " + garment_des
+                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
                  with torch.inference_mode():
                      (
                          prompt_embeds,
 
                          do_classifier_free_guidance=True,
                          negative_prompt=negative_prompt,
                      )
+
+                 prompt = "a photo of " + garment_des
+                 negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
                  if not isinstance(prompt, List):
                      prompt = [prompt] * 1
                  if not isinstance(negative_prompt, List):
 
                          negative_prompt=negative_prompt,
                      )
 
+
                  pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
                  garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
                  generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                 images = pipe(
                      prompt_embeds=prompt_embeds.to(device,torch.float16),
                      negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
                      pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
 
                      text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
                      cloth = garm_tensor.to(device,torch.float16),
                      mask_image=mask,
+                     image=human_img,
                      height=1024,
                      width=768,
                      ip_adapter_image = garm_img.resize((768,1024)),
                      guidance_scale=2.0,
+                 )[0]
 
      if is_checked_crop:
+         out_img = images[0].resize(crop_size)
+         human_img_orig.paste(out_img, (int(left), int(top)))
+         # return human_img_orig, mask_gray
+         return human_img_orig
      else:
+         # return images[0], mask_gray
+         return images[0]
+     # return images[0], mask_gray
 
  garm_list = os.listdir(os.path.join(example_path,"cloth"))
  garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
 
      ex_dict['composite'] = None
      human_ex_list.append(ex_dict)
 
+ ##default human
+
+
  image_blocks = gr.Blocks(theme="Nymbo/Nymbo_Theme").queue(max_size=12)
+
  with image_blocks as demo:
 
+
      with gr.Row():
          with gr.Column():
+             imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
              with gr.Row():
+                 is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
              with gr.Row():
+                 is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
+
+             with gr.Row(equal_height=True):
+                 example = gr.Examples(
+                     inputs=imgs,
+                     examples_per_page=5,
+                     examples=human_ex_list
                  )
 
          with gr.Column():
+             garm_img = gr.Image(label="Garment", sources='upload', type="pil")
              with gr.Row(elem_id="prompt-container"):
                  with gr.Row():
+                     prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
              example = gr.Examples(
                  inputs=garm_img,
+                 examples_per_page=8,
                  examples=garm_list_path)
+         # with gr.Column():
+         #     image_out = gr.Image(label="Output", elem_id="output-img", height=400)
+         #     masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
 
+         # masked_img = ()
+
+         with gr.Column():
+             # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
+             image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
+
          with gr.Column():
+             try_button = gr.Button(value="Try-on")
+             with gr.Accordion(label="Advanced Settings", open=False):
+                 with gr.Row():
+                     denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
+                     seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
+
+     try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out], api_name='tryon')
 
  image_blocks.launch(auth=("gini","pick"))
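
For reference, a minimal client-side sketch of calling the `tryon` endpoint registered by `try_button.click(..., api_name='tryon')` above, using `gradio_client`. The Space URL and the image file names are placeholders, the argument order follows the final `inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed]` list in this commit, and the exact payload shape expected for `gr.ImageEditor` can vary across Gradio versions, so treat this as an assumption-laden example rather than a verified recipe.

# Hypothetical usage sketch; assumes the Space is reachable at the placeholder URL,
# uses the auth tuple shown in image_blocks.launch(auth=("gini","pick")), and that a
# recent gradio_client (with handle_file) is installed.
from gradio_client import Client, handle_file

client = Client("https://<your-space>.hf.space", auth=("gini", "pick"))

result = client.predict(
    # gr.ImageEditor payload: person photo as background, no hand-drawn mask layers
    {"background": handle_file("person.jpg"), "layers": [], "composite": None},
    handle_file("garment.jpg"),           # garment image
    "Short Sleeve Round Neck T-shirts",   # garment description (prompt)
    True,                                 # is_checked: use auto-generated mask
    False,                                # is_checked_crop: auto-crop & resizing
    30,                                   # denoise_steps
    42,                                   # seed
    api_name="/tryon",
)
print(result)  # local path to the generated try-on image returned by the endpoint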