atlury commited on
Commit
b53c368
β€’
1 Parent(s): 211b491

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +14 -12
  2. app.py +313 -315
  3. apply_net.py +359 -359
  4. requirements.txt +22 -22
  5. utils_mask.py +167 -167
README.md CHANGED
@@ -1,12 +1,14 @@
1
- ---
2
- title: Jiovirtualtryon
3
- emoji: πŸš€
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 4.39.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
+ ---
2
+ title: IDM VTON
3
+ emoji: πŸ‘•πŸ‘”πŸ‘š
4
+ colorFrom: yellow
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.24.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-sa-4.0
11
+ short_description: High-fidelity Virtual Try-on
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,315 +1,313 @@
1
- import spaces
2
- import gradio as gr
3
-
4
- from PIL import Image
5
- from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
6
- from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
7
- from src.unet_hacked_tryon import UNet2DConditionModel
8
- from transformers import (
9
- CLIPImageProcessor,
10
- CLIPVisionModelWithProjection,
11
- CLIPTextModel,
12
- CLIPTextModelWithProjection,
13
- )
14
- from diffusers import DDPMScheduler,AutoencoderKL
15
- from typing import List
16
-
17
- import torch
18
- import os
19
- from transformers import AutoTokenizer
20
-
21
- import numpy as np
22
- from utils_mask import get_mask_location
23
- from torchvision import transforms
24
- import apply_net
25
- from preprocess.humanparsing.run_parsing import Parsing
26
- from preprocess.openpose.run_openpose import OpenPose
27
- from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
28
- from torchvision.transforms.functional import to_pil_image
29
-
30
-
31
- def pil_to_binary_mask(pil_image, threshold=0):
32
- np_image = np.array(pil_image)
33
- grayscale_image = Image.fromarray(np_image).convert("L")
34
- binary_mask = np.array(grayscale_image) > threshold
35
- mask = np.zeros(binary_mask.shape, dtype=np.uint8)
36
- for i in range(binary_mask.shape[0]):
37
- for j in range(binary_mask.shape[1]):
38
- if binary_mask[i,j] == True :
39
- mask[i,j] = 1
40
- mask = (mask*255).astype(np.uint8)
41
- output_mask = Image.fromarray(mask)
42
- return output_mask
43
-
44
-
45
- base_path = 'yisol/IDM-VTON'
46
- example_path = os.path.join(os.path.dirname(__file__), 'example')
47
-
48
- unet = UNet2DConditionModel.from_pretrained(
49
- base_path,
50
- subfolder="unet",
51
- torch_dtype=torch.float16,
52
- )
53
- unet.requires_grad_(False)
54
- tokenizer_one = AutoTokenizer.from_pretrained(
55
- base_path,
56
- subfolder="tokenizer",
57
- revision=None,
58
- use_fast=False,
59
- )
60
- tokenizer_two = AutoTokenizer.from_pretrained(
61
- base_path,
62
- subfolder="tokenizer_2",
63
- revision=None,
64
- use_fast=False,
65
- )
66
- noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
67
-
68
- text_encoder_one = CLIPTextModel.from_pretrained(
69
- base_path,
70
- subfolder="text_encoder",
71
- torch_dtype=torch.float16,
72
- )
73
- text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
74
- base_path,
75
- subfolder="text_encoder_2",
76
- torch_dtype=torch.float16,
77
- )
78
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
79
- base_path,
80
- subfolder="image_encoder",
81
- torch_dtype=torch.float16,
82
- )
83
- vae = AutoencoderKL.from_pretrained(base_path,
84
- subfolder="vae",
85
- torch_dtype=torch.float16,
86
- )
87
-
88
- # "stabilityai/stable-diffusion-xl-base-1.0",
89
- UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
90
- base_path,
91
- subfolder="unet_encoder",
92
- torch_dtype=torch.float16,
93
- )
94
-
95
- parsing_model = Parsing(0)
96
- openpose_model = OpenPose(0)
97
-
98
- UNet_Encoder.requires_grad_(False)
99
- image_encoder.requires_grad_(False)
100
- vae.requires_grad_(False)
101
- unet.requires_grad_(False)
102
- text_encoder_one.requires_grad_(False)
103
- text_encoder_two.requires_grad_(False)
104
- tensor_transfrom = transforms.Compose(
105
- [
106
- transforms.ToTensor(),
107
- transforms.Normalize([0.5], [0.5]),
108
- ]
109
- )
110
-
111
- pipe = TryonPipeline.from_pretrained(
112
- base_path,
113
- unet=unet,
114
- vae=vae,
115
- feature_extractor= CLIPImageProcessor(),
116
- text_encoder = text_encoder_one,
117
- text_encoder_2 = text_encoder_two,
118
- tokenizer = tokenizer_one,
119
- tokenizer_2 = tokenizer_two,
120
- scheduler = noise_scheduler,
121
- image_encoder=image_encoder,
122
- torch_dtype=torch.float16,
123
- )
124
- pipe.unet_encoder = UNet_Encoder
125
-
126
- @spaces.GPU
127
- def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
128
- device = "cuda"
129
-
130
- openpose_model.preprocessor.body_estimation.model.to(device)
131
- pipe.to(device)
132
- pipe.unet_encoder.to(device)
133
-
134
- garm_img= garm_img.convert("RGB").resize((768,1024))
135
- human_img_orig = dict["background"].convert("RGB")
136
-
137
- if is_checked_crop:
138
- width, height = human_img_orig.size
139
- target_width = int(min(width, height * (3 / 4)))
140
- target_height = int(min(height, width * (4 / 3)))
141
- left = (width - target_width) / 2
142
- top = (height - target_height) / 2
143
- right = (width + target_width) / 2
144
- bottom = (height + target_height) / 2
145
- cropped_img = human_img_orig.crop((left, top, right, bottom))
146
- crop_size = cropped_img.size
147
- human_img = cropped_img.resize((768,1024))
148
- else:
149
- human_img = human_img_orig.resize((768,1024))
150
-
151
-
152
- if is_checked:
153
- keypoints = openpose_model(human_img.resize((384,512)))
154
- model_parse, _ = parsing_model(human_img.resize((384,512)))
155
- mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
156
- mask = mask.resize((768,1024))
157
- else:
158
- mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
159
- # mask = transforms.ToTensor()(mask)
160
- # mask = mask.unsqueeze(0)
161
- mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
162
- mask_gray = to_pil_image((mask_gray+1.0)/2.0)
163
-
164
-
165
- human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
166
- human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
167
-
168
-
169
-
170
- args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
171
- # verbosity = getattr(args, "verbosity", None)
172
- pose_img = args.func(args,human_img_arg)
173
- pose_img = pose_img[:,:,::-1]
174
- pose_img = Image.fromarray(pose_img).resize((768,1024))
175
-
176
- with torch.no_grad():
177
- # Extract the images
178
- with torch.cuda.amp.autocast():
179
- with torch.no_grad():
180
- prompt = "model is wearing " + garment_des
181
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
182
- with torch.inference_mode():
183
- (
184
- prompt_embeds,
185
- negative_prompt_embeds,
186
- pooled_prompt_embeds,
187
- negative_pooled_prompt_embeds,
188
- ) = pipe.encode_prompt(
189
- prompt,
190
- num_images_per_prompt=1,
191
- do_classifier_free_guidance=True,
192
- negative_prompt=negative_prompt,
193
- )
194
-
195
- prompt = "a photo of " + garment_des
196
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
197
- if not isinstance(prompt, List):
198
- prompt = [prompt] * 1
199
- if not isinstance(negative_prompt, List):
200
- negative_prompt = [negative_prompt] * 1
201
- with torch.inference_mode():
202
- (
203
- prompt_embeds_c,
204
- _,
205
- _,
206
- _,
207
- ) = pipe.encode_prompt(
208
- prompt,
209
- num_images_per_prompt=1,
210
- do_classifier_free_guidance=False,
211
- negative_prompt=negative_prompt,
212
- )
213
-
214
-
215
-
216
- pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
217
- garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
218
- generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
219
- images = pipe(
220
- prompt_embeds=prompt_embeds.to(device,torch.float16),
221
- negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
222
- pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
223
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
224
- num_inference_steps=denoise_steps,
225
- generator=generator,
226
- strength = 1.0,
227
- pose_img = pose_img.to(device,torch.float16),
228
- text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
229
- cloth = garm_tensor.to(device,torch.float16),
230
- mask_image=mask,
231
- image=human_img,
232
- height=1024,
233
- width=768,
234
- ip_adapter_image = garm_img.resize((768,1024)),
235
- guidance_scale=2.0,
236
- )[0]
237
-
238
- if is_checked_crop:
239
- out_img = images[0].resize(crop_size)
240
- human_img_orig.paste(out_img, (int(left), int(top)))
241
- return human_img_orig, mask_gray
242
- else:
243
- return images[0], mask_gray
244
- # return images[0], mask_gray
245
-
246
- garm_list = os.listdir(os.path.join(example_path,"cloth"))
247
- garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
248
-
249
- human_list = os.listdir(os.path.join(example_path,"human"))
250
- human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
251
-
252
- human_ex_list = []
253
- for ex_human in human_list_path:
254
- ex_dict= {}
255
- ex_dict['background'] = ex_human
256
- ex_dict['layers'] = None
257
- ex_dict['composite'] = None
258
- human_ex_list.append(ex_dict)
259
-
260
- ##default human
261
-
262
-
263
- image_blocks = gr.Blocks().queue()
264
- with image_blocks as demo:
265
- gr.Markdown("## IDM-VTON πŸ‘•πŸ‘”πŸ‘š")
266
- gr.Markdown("Virtual Try-on with your image and garment image. Check out the [source codes](https://github.com/yisol/IDM-VTON) and the [model](https://huggingface.co/yisol/IDM-VTON)")
267
- with gr.Row():
268
- with gr.Column():
269
- imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
270
- with gr.Row():
271
- is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
272
- with gr.Row():
273
- is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
274
-
275
- example = gr.Examples(
276
- inputs=imgs,
277
- examples_per_page=10,
278
- examples=human_ex_list
279
- )
280
-
281
- with gr.Column():
282
- garm_img = gr.Image(label="Garment", sources='upload', type="pil")
283
- with gr.Row(elem_id="prompt-container"):
284
- with gr.Row():
285
- prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
286
- example = gr.Examples(
287
- inputs=garm_img,
288
- examples_per_page=8,
289
- examples=garm_list_path)
290
- with gr.Column():
291
- # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
292
- masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
293
- with gr.Column():
294
- # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
295
- image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
296
-
297
-
298
-
299
-
300
- with gr.Column():
301
- try_button = gr.Button(value="Try-on")
302
- with gr.Accordion(label="Advanced Settings", open=False):
303
- with gr.Row():
304
- denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
305
- seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
306
-
307
-
308
-
309
- try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
310
-
311
-
312
-
313
-
314
- image_blocks.launch()
315
-
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
4
+ from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
5
+ from src.unet_hacked_tryon import UNet2DConditionModel
6
+ from transformers import (
7
+ CLIPImageProcessor,
8
+ CLIPVisionModelWithProjection,
9
+ CLIPTextModel,
10
+ CLIPTextModelWithProjection,
11
+ )
12
+ from diffusers import DDPMScheduler,AutoencoderKL
13
+ from typing import List
14
+
15
+ import torch
16
+ import os
17
+ from transformers import AutoTokenizer
18
+ import spaces
19
+ import numpy as np
20
+ from utils_mask import get_mask_location
21
+ from torchvision import transforms
22
+ import apply_net
23
+ from preprocess.humanparsing.run_parsing import Parsing
24
+ from preprocess.openpose.run_openpose import OpenPose
25
+ from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
26
+ from torchvision.transforms.functional import to_pil_image
27
+
28
+
29
+ def pil_to_binary_mask(pil_image, threshold=0):
30
+ np_image = np.array(pil_image)
31
+ grayscale_image = Image.fromarray(np_image).convert("L")
32
+ binary_mask = np.array(grayscale_image) > threshold
33
+ mask = np.zeros(binary_mask.shape, dtype=np.uint8)
34
+ for i in range(binary_mask.shape[0]):
35
+ for j in range(binary_mask.shape[1]):
36
+ if binary_mask[i,j] == True :
37
+ mask[i,j] = 1
38
+ mask = (mask*255).astype(np.uint8)
39
+ output_mask = Image.fromarray(mask)
40
+ return output_mask
41
+
42
+
43
+ base_path = 'yisol/IDM-VTON'
44
+ example_path = os.path.join(os.path.dirname(__file__), 'example')
45
+
46
+ unet = UNet2DConditionModel.from_pretrained(
47
+ base_path,
48
+ subfolder="unet",
49
+ torch_dtype=torch.float16,
50
+ )
51
+ unet.requires_grad_(False)
52
+ tokenizer_one = AutoTokenizer.from_pretrained(
53
+ base_path,
54
+ subfolder="tokenizer",
55
+ revision=None,
56
+ use_fast=False,
57
+ )
58
+ tokenizer_two = AutoTokenizer.from_pretrained(
59
+ base_path,
60
+ subfolder="tokenizer_2",
61
+ revision=None,
62
+ use_fast=False,
63
+ )
64
+ noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
65
+
66
+ text_encoder_one = CLIPTextModel.from_pretrained(
67
+ base_path,
68
+ subfolder="text_encoder",
69
+ torch_dtype=torch.float16,
70
+ )
71
+ text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
72
+ base_path,
73
+ subfolder="text_encoder_2",
74
+ torch_dtype=torch.float16,
75
+ )
76
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
77
+ base_path,
78
+ subfolder="image_encoder",
79
+ torch_dtype=torch.float16,
80
+ )
81
+ vae = AutoencoderKL.from_pretrained(base_path,
82
+ subfolder="vae",
83
+ torch_dtype=torch.float16,
84
+ )
85
+
86
+ # "stabilityai/stable-diffusion-xl-base-1.0",
87
+ UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
88
+ base_path,
89
+ subfolder="unet_encoder",
90
+ torch_dtype=torch.float16,
91
+ )
92
+
93
+ parsing_model = Parsing(0)
94
+ openpose_model = OpenPose(0)
95
+
96
+ UNet_Encoder.requires_grad_(False)
97
+ image_encoder.requires_grad_(False)
98
+ vae.requires_grad_(False)
99
+ unet.requires_grad_(False)
100
+ text_encoder_one.requires_grad_(False)
101
+ text_encoder_two.requires_grad_(False)
102
+ tensor_transfrom = transforms.Compose(
103
+ [
104
+ transforms.ToTensor(),
105
+ transforms.Normalize([0.5], [0.5]),
106
+ ]
107
+ )
108
+
109
+ pipe = TryonPipeline.from_pretrained(
110
+ base_path,
111
+ unet=unet,
112
+ vae=vae,
113
+ feature_extractor= CLIPImageProcessor(),
114
+ text_encoder = text_encoder_one,
115
+ text_encoder_2 = text_encoder_two,
116
+ tokenizer = tokenizer_one,
117
+ tokenizer_2 = tokenizer_two,
118
+ scheduler = noise_scheduler,
119
+ image_encoder=image_encoder,
120
+ torch_dtype=torch.float16,
121
+ )
122
+ pipe.unet_encoder = UNet_Encoder
123
+
124
+ @spaces.GPU
125
+ def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed):
126
+ device = "cuda"
127
+
128
+ openpose_model.preprocessor.body_estimation.model.to(device)
129
+ pipe.to(device)
130
+ pipe.unet_encoder.to(device)
131
+
132
+ garm_img= garm_img.convert("RGB").resize((768,1024))
133
+ human_img_orig = dict["background"].convert("RGB")
134
+
135
+ if is_checked_crop:
136
+ width, height = human_img_orig.size
137
+ target_width = int(min(width, height * (3 / 4)))
138
+ target_height = int(min(height, width * (4 / 3)))
139
+ left = (width - target_width) / 2
140
+ top = (height - target_height) / 2
141
+ right = (width + target_width) / 2
142
+ bottom = (height + target_height) / 2
143
+ cropped_img = human_img_orig.crop((left, top, right, bottom))
144
+ crop_size = cropped_img.size
145
+ human_img = cropped_img.resize((768,1024))
146
+ else:
147
+ human_img = human_img_orig.resize((768,1024))
148
+
149
+
150
+ if is_checked:
151
+ keypoints = openpose_model(human_img.resize((384,512)))
152
+ model_parse, _ = parsing_model(human_img.resize((384,512)))
153
+ mask, mask_gray = get_mask_location('hd', "upper_body", model_parse, keypoints)
154
+ mask = mask.resize((768,1024))
155
+ else:
156
+ mask = pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
157
+ # mask = transforms.ToTensor()(mask)
158
+ # mask = mask.unsqueeze(0)
159
+ mask_gray = (1-transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
160
+ mask_gray = to_pil_image((mask_gray+1.0)/2.0)
161
+
162
+
163
+ human_img_arg = _apply_exif_orientation(human_img.resize((384,512)))
164
+ human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
165
+
166
+
167
+
168
+ args = apply_net.create_argument_parser().parse_args(('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
169
+ # verbosity = getattr(args, "verbosity", None)
170
+ pose_img = args.func(args,human_img_arg)
171
+ pose_img = pose_img[:,:,::-1]
172
+ pose_img = Image.fromarray(pose_img).resize((768,1024))
173
+
174
+ with torch.no_grad():
175
+ # Extract the images
176
+ with torch.cuda.amp.autocast():
177
+ with torch.no_grad():
178
+ prompt = "model is wearing " + garment_des
179
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
180
+ with torch.inference_mode():
181
+ (
182
+ prompt_embeds,
183
+ negative_prompt_embeds,
184
+ pooled_prompt_embeds,
185
+ negative_pooled_prompt_embeds,
186
+ ) = pipe.encode_prompt(
187
+ prompt,
188
+ num_images_per_prompt=1,
189
+ do_classifier_free_guidance=True,
190
+ negative_prompt=negative_prompt,
191
+ )
192
+
193
+ prompt = "a photo of " + garment_des
194
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
195
+ if not isinstance(prompt, List):
196
+ prompt = [prompt] * 1
197
+ if not isinstance(negative_prompt, List):
198
+ negative_prompt = [negative_prompt] * 1
199
+ with torch.inference_mode():
200
+ (
201
+ prompt_embeds_c,
202
+ _,
203
+ _,
204
+ _,
205
+ ) = pipe.encode_prompt(
206
+ prompt,
207
+ num_images_per_prompt=1,
208
+ do_classifier_free_guidance=False,
209
+ negative_prompt=negative_prompt,
210
+ )
211
+
212
+
213
+
214
+ pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device,torch.float16)
215
+ garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device,torch.float16)
216
+ generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
217
+ images = pipe(
218
+ prompt_embeds=prompt_embeds.to(device,torch.float16),
219
+ negative_prompt_embeds=negative_prompt_embeds.to(device,torch.float16),
220
+ pooled_prompt_embeds=pooled_prompt_embeds.to(device,torch.float16),
221
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device,torch.float16),
222
+ num_inference_steps=denoise_steps,
223
+ generator=generator,
224
+ strength = 1.0,
225
+ pose_img = pose_img.to(device,torch.float16),
226
+ text_embeds_cloth=prompt_embeds_c.to(device,torch.float16),
227
+ cloth = garm_tensor.to(device,torch.float16),
228
+ mask_image=mask,
229
+ image=human_img,
230
+ height=1024,
231
+ width=768,
232
+ ip_adapter_image = garm_img.resize((768,1024)),
233
+ guidance_scale=2.0,
234
+ )[0]
235
+
236
+ if is_checked_crop:
237
+ out_img = images[0].resize(crop_size)
238
+ human_img_orig.paste(out_img, (int(left), int(top)))
239
+ return human_img_orig, mask_gray
240
+ else:
241
+ return images[0], mask_gray
242
+ # return images[0], mask_gray
243
+
244
+ garm_list = os.listdir(os.path.join(example_path,"cloth"))
245
+ garm_list_path = [os.path.join(example_path,"cloth",garm) for garm in garm_list]
246
+
247
+ human_list = os.listdir(os.path.join(example_path,"human"))
248
+ human_list_path = [os.path.join(example_path,"human",human) for human in human_list]
249
+
250
+ human_ex_list = []
251
+ for ex_human in human_list_path:
252
+ ex_dict= {}
253
+ ex_dict['background'] = ex_human
254
+ ex_dict['layers'] = None
255
+ ex_dict['composite'] = None
256
+ human_ex_list.append(ex_dict)
257
+
258
+ ##default human
259
+
260
+
261
+ image_blocks = gr.Blocks().queue()
262
+ with image_blocks as demo:
263
+ gr.Markdown("## IDM-VTON πŸ‘•πŸ‘”πŸ‘š")
264
+ gr.Markdown("Virtual Try-on with your image and garment image. Check out the [source codes](https://github.com/yisol/IDM-VTON) and the [model](https://huggingface.co/yisol/IDM-VTON)")
265
+ with gr.Row():
266
+ with gr.Column():
267
+ imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
268
+ with gr.Row():
269
+ is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)",value=True)
270
+ with gr.Row():
271
+ is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing",value=False)
272
+
273
+ example = gr.Examples(
274
+ inputs=imgs,
275
+ examples_per_page=10,
276
+ examples=human_ex_list
277
+ )
278
+
279
+ with gr.Column():
280
+ garm_img = gr.Image(label="Garment", sources='upload', type="pil")
281
+ with gr.Row(elem_id="prompt-container"):
282
+ with gr.Row():
283
+ prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
284
+ example = gr.Examples(
285
+ inputs=garm_img,
286
+ examples_per_page=8,
287
+ examples=garm_list_path)
288
+ with gr.Column():
289
+ # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
290
+ masked_img = gr.Image(label="Masked image output", elem_id="masked-img",show_share_button=False)
291
+ with gr.Column():
292
+ # image_out = gr.Image(label="Output", elem_id="output-img", height=400)
293
+ image_out = gr.Image(label="Output", elem_id="output-img",show_share_button=False)
294
+
295
+
296
+
297
+
298
+ with gr.Column():
299
+ try_button = gr.Button(value="Try-on")
300
+ with gr.Accordion(label="Advanced Settings", open=False):
301
+ with gr.Row():
302
+ denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
303
+ seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)
304
+
305
+
306
+
307
+ try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked,is_checked_crop, denoise_steps, seed], outputs=[image_out,masked_img], api_name='tryon')
308
+
309
+
310
+
311
+
312
+ image_blocks.launch()
313
+
 
 
apply_net.py CHANGED
@@ -1,359 +1,359 @@
1
- #!/usr/bin/env python3
2
- # Copyright (c) Facebook, Inc. and its affiliates.
3
-
4
- import argparse
5
- import glob
6
- import logging
7
- import os
8
- import sys
9
- from typing import Any, ClassVar, Dict, List
10
- import torch
11
-
12
- from detectron2.config import CfgNode, get_cfg
13
- from detectron2.data.detection_utils import read_image
14
- from detectron2.engine.defaults import DefaultPredictor
15
- from detectron2.structures.instances import Instances
16
- from detectron2.utils.logger import setup_logger
17
-
18
- from densepose import add_densepose_config
19
- from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
20
- from densepose.utils.logger import verbosity_to_level
21
- from densepose.vis.base import CompoundVisualizer
22
- from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
23
- from densepose.vis.densepose_outputs_vertex import (
24
- DensePoseOutputsTextureVisualizer,
25
- DensePoseOutputsVertexVisualizer,
26
- get_texture_atlases,
27
- )
28
- from densepose.vis.densepose_results import (
29
- DensePoseResultsContourVisualizer,
30
- DensePoseResultsFineSegmentationVisualizer,
31
- DensePoseResultsUVisualizer,
32
- DensePoseResultsVVisualizer,
33
- )
34
- from densepose.vis.densepose_results_textures import (
35
- DensePoseResultsVisualizerWithTexture,
36
- get_texture_atlas,
37
- )
38
- from densepose.vis.extractor import (
39
- CompoundExtractor,
40
- DensePoseOutputsExtractor,
41
- DensePoseResultExtractor,
42
- create_extractor,
43
- )
44
-
45
- DOC = """Apply Net - a tool to print / visualize DensePose results
46
- """
47
-
48
- LOGGER_NAME = "apply_net"
49
- logger = logging.getLogger(LOGGER_NAME)
50
-
51
- _ACTION_REGISTRY: Dict[str, "Action"] = {}
52
-
53
-
54
- class Action:
55
- @classmethod
56
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
57
- parser.add_argument(
58
- "-v",
59
- "--verbosity",
60
- action="count",
61
- help="Verbose mode. Multiple -v options increase the verbosity.",
62
- )
63
-
64
-
65
- def register_action(cls: type):
66
- """
67
- Decorator for action classes to automate action registration
68
- """
69
- global _ACTION_REGISTRY
70
- _ACTION_REGISTRY[cls.COMMAND] = cls
71
- return cls
72
-
73
-
74
- class InferenceAction(Action):
75
- @classmethod
76
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
77
- super(InferenceAction, cls).add_arguments(parser)
78
- parser.add_argument("cfg", metavar="<config>", help="Config file")
79
- parser.add_argument("model", metavar="<model>", help="Model file")
80
- parser.add_argument(
81
- "--opts",
82
- help="Modify config options using the command-line 'KEY VALUE' pairs",
83
- default=[],
84
- nargs=argparse.REMAINDER,
85
- )
86
-
87
- @classmethod
88
- def execute(cls: type, args: argparse.Namespace, human_img):
89
- logger.info(f"Loading config from {args.cfg}")
90
- opts = []
91
- cfg = cls.setup_config(args.cfg, args.model, args, opts)
92
- logger.info(f"Loading model from {args.model}")
93
- predictor = DefaultPredictor(cfg)
94
- # logger.info(f"Loading data from {args.input}")
95
- # file_list = cls._get_input_file_list(args.input)
96
- # if len(file_list) == 0:
97
- # logger.warning(f"No input images for {args.input}")
98
- # return
99
- context = cls.create_context(args, cfg)
100
- # for file_name in file_list:
101
- # img = read_image(file_name, format="BGR") # predictor expects BGR image.
102
- with torch.no_grad():
103
- outputs = predictor(human_img)["instances"]
104
- out_pose = cls.execute_on_outputs(context, {"image": human_img}, outputs)
105
- cls.postexecute(context)
106
- return out_pose
107
-
108
- @classmethod
109
- def setup_config(
110
- cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
111
- ):
112
- cfg = get_cfg()
113
- add_densepose_config(cfg)
114
- cfg.merge_from_file(config_fpath)
115
- cfg.merge_from_list(args.opts)
116
- if opts:
117
- cfg.merge_from_list(opts)
118
- cfg.MODEL.WEIGHTS = model_fpath
119
- cfg.freeze()
120
- return cfg
121
-
122
- @classmethod
123
- def _get_input_file_list(cls: type, input_spec: str):
124
- if os.path.isdir(input_spec):
125
- file_list = [
126
- os.path.join(input_spec, fname)
127
- for fname in os.listdir(input_spec)
128
- if os.path.isfile(os.path.join(input_spec, fname))
129
- ]
130
- elif os.path.isfile(input_spec):
131
- file_list = [input_spec]
132
- else:
133
- file_list = glob.glob(input_spec)
134
- return file_list
135
-
136
-
137
- @register_action
138
- class DumpAction(InferenceAction):
139
- """
140
- Dump action that outputs results to a pickle file
141
- """
142
-
143
- COMMAND: ClassVar[str] = "dump"
144
-
145
- @classmethod
146
- def add_parser(cls: type, subparsers: argparse._SubParsersAction):
147
- parser = subparsers.add_parser(cls.COMMAND, help="Dump model outputs to a file.")
148
- cls.add_arguments(parser)
149
- parser.set_defaults(func=cls.execute)
150
-
151
- @classmethod
152
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
153
- super(DumpAction, cls).add_arguments(parser)
154
- parser.add_argument(
155
- "--output",
156
- metavar="<dump_file>",
157
- default="results.pkl",
158
- help="File name to save dump to",
159
- )
160
-
161
- @classmethod
162
- def execute_on_outputs(
163
- cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
164
- ):
165
- image_fpath = entry["file_name"]
166
- logger.info(f"Processing {image_fpath}")
167
- result = {"file_name": image_fpath}
168
- if outputs.has("scores"):
169
- result["scores"] = outputs.get("scores").cpu()
170
- if outputs.has("pred_boxes"):
171
- result["pred_boxes_XYXY"] = outputs.get("pred_boxes").tensor.cpu()
172
- if outputs.has("pred_densepose"):
173
- if isinstance(outputs.pred_densepose, DensePoseChartPredictorOutput):
174
- extractor = DensePoseResultExtractor()
175
- elif isinstance(outputs.pred_densepose, DensePoseEmbeddingPredictorOutput):
176
- extractor = DensePoseOutputsExtractor()
177
- result["pred_densepose"] = extractor(outputs)[0]
178
- context["results"].append(result)
179
-
180
- @classmethod
181
- def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode):
182
- context = {"results": [], "out_fname": args.output}
183
- return context
184
-
185
- @classmethod
186
- def postexecute(cls: type, context: Dict[str, Any]):
187
- out_fname = context["out_fname"]
188
- out_dir = os.path.dirname(out_fname)
189
- if len(out_dir) > 0 and not os.path.exists(out_dir):
190
- os.makedirs(out_dir)
191
- with open(out_fname, "wb") as hFile:
192
- torch.save(context["results"], hFile)
193
- logger.info(f"Output saved to {out_fname}")
194
-
195
-
196
- @register_action
197
- class ShowAction(InferenceAction):
198
- """
199
- Show action that visualizes selected entries on an image
200
- """
201
-
202
- COMMAND: ClassVar[str] = "show"
203
- VISUALIZERS: ClassVar[Dict[str, object]] = {
204
- "dp_contour": DensePoseResultsContourVisualizer,
205
- "dp_segm": DensePoseResultsFineSegmentationVisualizer,
206
- "dp_u": DensePoseResultsUVisualizer,
207
- "dp_v": DensePoseResultsVVisualizer,
208
- "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
209
- "dp_cse_texture": DensePoseOutputsTextureVisualizer,
210
- "dp_vertex": DensePoseOutputsVertexVisualizer,
211
- "bbox": ScoredBoundingBoxVisualizer,
212
- }
213
-
214
- @classmethod
215
- def add_parser(cls: type, subparsers: argparse._SubParsersAction):
216
- parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
217
- cls.add_arguments(parser)
218
- parser.set_defaults(func=cls.execute)
219
-
220
- @classmethod
221
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
222
- super(ShowAction, cls).add_arguments(parser)
223
- parser.add_argument(
224
- "visualizations",
225
- metavar="<visualizations>",
226
- help="Comma separated list of visualizations, possible values: "
227
- "[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
228
- )
229
- parser.add_argument(
230
- "--min_score",
231
- metavar="<score>",
232
- default=0.8,
233
- type=float,
234
- help="Minimum detection score to visualize",
235
- )
236
- parser.add_argument(
237
- "--nms_thresh", metavar="<threshold>", default=None, type=float, help="NMS threshold"
238
- )
239
- parser.add_argument(
240
- "--texture_atlas",
241
- metavar="<texture_atlas>",
242
- default=None,
243
- help="Texture atlas file (for IUV texture transfer)",
244
- )
245
- parser.add_argument(
246
- "--texture_atlases_map",
247
- metavar="<texture_atlases_map>",
248
- default=None,
249
- help="JSON string of a dict containing texture atlas files for each mesh",
250
- )
251
- parser.add_argument(
252
- "--output",
253
- metavar="<image_file>",
254
- default="outputres.png",
255
- help="File name to save output to",
256
- )
257
-
258
- @classmethod
259
- def setup_config(
260
- cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
261
- ):
262
- opts.append("MODEL.ROI_HEADS.SCORE_THRESH_TEST")
263
- opts.append(str(args.min_score))
264
- if args.nms_thresh is not None:
265
- opts.append("MODEL.ROI_HEADS.NMS_THRESH_TEST")
266
- opts.append(str(args.nms_thresh))
267
- cfg = super(ShowAction, cls).setup_config(config_fpath, model_fpath, args, opts)
268
- return cfg
269
-
270
- @classmethod
271
- def execute_on_outputs(
272
- cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
273
- ):
274
- import cv2
275
- import numpy as np
276
- visualizer = context["visualizer"]
277
- extractor = context["extractor"]
278
- # image_fpath = entry["file_name"]
279
- # logger.info(f"Processing {image_fpath}")
280
- image = cv2.cvtColor(entry["image"], cv2.COLOR_BGR2GRAY)
281
- image = np.tile(image[:, :, np.newaxis], [1, 1, 3])
282
- data = extractor(outputs)
283
- image_vis = visualizer.visualize(image, data)
284
-
285
- return image_vis
286
- entry_idx = context["entry_idx"] + 1
287
- out_fname = './image-densepose/' + image_fpath.split('/')[-1]
288
- out_dir = './image-densepose'
289
- out_dir = os.path.dirname(out_fname)
290
- if len(out_dir) > 0 and not os.path.exists(out_dir):
291
- os.makedirs(out_dir)
292
- cv2.imwrite(out_fname, image_vis)
293
- logger.info(f"Output saved to {out_fname}")
294
- context["entry_idx"] += 1
295
-
296
- @classmethod
297
- def postexecute(cls: type, context: Dict[str, Any]):
298
- pass
299
- # python ./apply_net.py show ./configs/densepose_rcnn_R_50_FPN_s1x.yaml https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl /home/alin0222/DressCode/upper_body/images dp_segm -v --opts MODEL.DEVICE cpu
300
-
301
- @classmethod
302
- def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
303
- base, ext = os.path.splitext(fname_base)
304
- return base + ".{0:04d}".format(entry_idx) + ext
305
-
306
- @classmethod
307
- def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode) -> Dict[str, Any]:
308
- vis_specs = args.visualizations.split(",")
309
- visualizers = []
310
- extractors = []
311
- for vis_spec in vis_specs:
312
- texture_atlas = get_texture_atlas(args.texture_atlas)
313
- texture_atlases_dict = get_texture_atlases(args.texture_atlases_map)
314
- vis = cls.VISUALIZERS[vis_spec](
315
- cfg=cfg,
316
- texture_atlas=texture_atlas,
317
- texture_atlases_dict=texture_atlases_dict,
318
- )
319
- visualizers.append(vis)
320
- extractor = create_extractor(vis)
321
- extractors.append(extractor)
322
- visualizer = CompoundVisualizer(visualizers)
323
- extractor = CompoundExtractor(extractors)
324
- context = {
325
- "extractor": extractor,
326
- "visualizer": visualizer,
327
- "out_fname": args.output,
328
- "entry_idx": 0,
329
- }
330
- return context
331
-
332
-
333
- def create_argument_parser() -> argparse.ArgumentParser:
334
- parser = argparse.ArgumentParser(
335
- description=DOC,
336
- formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=120),
337
- )
338
- parser.set_defaults(func=lambda _: parser.print_help(sys.stdout))
339
- subparsers = parser.add_subparsers(title="Actions")
340
- for _, action in _ACTION_REGISTRY.items():
341
- action.add_parser(subparsers)
342
- return parser
343
-
344
-
345
- def main():
346
- parser = create_argument_parser()
347
- args = parser.parse_args()
348
- verbosity = getattr(args, "verbosity", None)
349
- global logger
350
- logger = setup_logger(name=LOGGER_NAME)
351
- logger.setLevel(verbosity_to_level(verbosity))
352
- args.func(args)
353
-
354
-
355
- if __name__ == "__main__":
356
- main()
357
-
358
-
359
- # python ./apply_net.py show ./configs/densepose_rcnn_R_50_FPN_s1x.yaml https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl /home/alin0222/Dresscode/dresses/humanonly dp_segm -v --opts MODEL.DEVICE cuda
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ import argparse
5
+ import glob
6
+ import logging
7
+ import os
8
+ import sys
9
+ from typing import Any, ClassVar, Dict, List
10
+ import torch
11
+
12
+ from detectron2.config import CfgNode, get_cfg
13
+ from detectron2.data.detection_utils import read_image
14
+ from detectron2.engine.defaults import DefaultPredictor
15
+ from detectron2.structures.instances import Instances
16
+ from detectron2.utils.logger import setup_logger
17
+
18
+ from densepose import add_densepose_config
19
+ from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
20
+ from densepose.utils.logger import verbosity_to_level
21
+ from densepose.vis.base import CompoundVisualizer
22
+ from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
23
+ from densepose.vis.densepose_outputs_vertex import (
24
+ DensePoseOutputsTextureVisualizer,
25
+ DensePoseOutputsVertexVisualizer,
26
+ get_texture_atlases,
27
+ )
28
+ from densepose.vis.densepose_results import (
29
+ DensePoseResultsContourVisualizer,
30
+ DensePoseResultsFineSegmentationVisualizer,
31
+ DensePoseResultsUVisualizer,
32
+ DensePoseResultsVVisualizer,
33
+ )
34
+ from densepose.vis.densepose_results_textures import (
35
+ DensePoseResultsVisualizerWithTexture,
36
+ get_texture_atlas,
37
+ )
38
+ from densepose.vis.extractor import (
39
+ CompoundExtractor,
40
+ DensePoseOutputsExtractor,
41
+ DensePoseResultExtractor,
42
+ create_extractor,
43
+ )
44
+
45
+ DOC = """Apply Net - a tool to print / visualize DensePose results
46
+ """
47
+
48
+ LOGGER_NAME = "apply_net"
49
+ logger = logging.getLogger(LOGGER_NAME)
50
+
51
+ _ACTION_REGISTRY: Dict[str, "Action"] = {}
52
+
53
+
54
+ class Action:
55
+ @classmethod
56
+ def add_arguments(cls: type, parser: argparse.ArgumentParser):
57
+ parser.add_argument(
58
+ "-v",
59
+ "--verbosity",
60
+ action="count",
61
+ help="Verbose mode. Multiple -v options increase the verbosity.",
62
+ )
63
+
64
+
65
+ def register_action(cls: type):
66
+ """
67
+ Decorator for action classes to automate action registration
68
+ """
69
+ global _ACTION_REGISTRY
70
+ _ACTION_REGISTRY[cls.COMMAND] = cls
71
+ return cls
72
+
73
+
74
+ class InferenceAction(Action):
75
+ @classmethod
76
+ def add_arguments(cls: type, parser: argparse.ArgumentParser):
77
+ super(InferenceAction, cls).add_arguments(parser)
78
+ parser.add_argument("cfg", metavar="<config>", help="Config file")
79
+ parser.add_argument("model", metavar="<model>", help="Model file")
80
+ parser.add_argument(
81
+ "--opts",
82
+ help="Modify config options using the command-line 'KEY VALUE' pairs",
83
+ default=[],
84
+ nargs=argparse.REMAINDER,
85
+ )
86
+
87
+ @classmethod
88
+ def execute(cls: type, args: argparse.Namespace, human_img):
89
+ logger.info(f"Loading config from {args.cfg}")
90
+ opts = []
91
+ cfg = cls.setup_config(args.cfg, args.model, args, opts)
92
+ logger.info(f"Loading model from {args.model}")
93
+ predictor = DefaultPredictor(cfg)
94
+ # logger.info(f"Loading data from {args.input}")
95
+ # file_list = cls._get_input_file_list(args.input)
96
+ # if len(file_list) == 0:
97
+ # logger.warning(f"No input images for {args.input}")
98
+ # return
99
+ context = cls.create_context(args, cfg)
100
+ # for file_name in file_list:
101
+ # img = read_image(file_name, format="BGR") # predictor expects BGR image.
102
+ with torch.no_grad():
103
+ outputs = predictor(human_img)["instances"]
104
+ out_pose = cls.execute_on_outputs(context, {"image": human_img}, outputs)
105
+ cls.postexecute(context)
106
+ return out_pose
107
+
108
+ @classmethod
109
+ def setup_config(
110
+ cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
111
+ ):
112
+ cfg = get_cfg()
113
+ add_densepose_config(cfg)
114
+ cfg.merge_from_file(config_fpath)
115
+ cfg.merge_from_list(args.opts)
116
+ if opts:
117
+ cfg.merge_from_list(opts)
118
+ cfg.MODEL.WEIGHTS = model_fpath
119
+ cfg.freeze()
120
+ return cfg
121
+
122
+ @classmethod
123
+ def _get_input_file_list(cls: type, input_spec: str):
124
+ if os.path.isdir(input_spec):
125
+ file_list = [
126
+ os.path.join(input_spec, fname)
127
+ for fname in os.listdir(input_spec)
128
+ if os.path.isfile(os.path.join(input_spec, fname))
129
+ ]
130
+ elif os.path.isfile(input_spec):
131
+ file_list = [input_spec]
132
+ else:
133
+ file_list = glob.glob(input_spec)
134
+ return file_list
135
+
136
+
137
+ @register_action
138
+ class DumpAction(InferenceAction):
139
+ """
140
+ Dump action that outputs results to a pickle file
141
+ """
142
+
143
+ COMMAND: ClassVar[str] = "dump"
144
+
145
+ @classmethod
146
+ def add_parser(cls: type, subparsers: argparse._SubParsersAction):
147
+ parser = subparsers.add_parser(cls.COMMAND, help="Dump model outputs to a file.")
148
+ cls.add_arguments(parser)
149
+ parser.set_defaults(func=cls.execute)
150
+
151
+ @classmethod
152
+ def add_arguments(cls: type, parser: argparse.ArgumentParser):
153
+ super(DumpAction, cls).add_arguments(parser)
154
+ parser.add_argument(
155
+ "--output",
156
+ metavar="<dump_file>",
157
+ default="results.pkl",
158
+ help="File name to save dump to",
159
+ )
160
+
161
+ @classmethod
162
+ def execute_on_outputs(
163
+ cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
164
+ ):
165
+ image_fpath = entry["file_name"]
166
+ logger.info(f"Processing {image_fpath}")
167
+ result = {"file_name": image_fpath}
168
+ if outputs.has("scores"):
169
+ result["scores"] = outputs.get("scores").cpu()
170
+ if outputs.has("pred_boxes"):
171
+ result["pred_boxes_XYXY"] = outputs.get("pred_boxes").tensor.cpu()
172
+ if outputs.has("pred_densepose"):
173
+ if isinstance(outputs.pred_densepose, DensePoseChartPredictorOutput):
174
+ extractor = DensePoseResultExtractor()
175
+ elif isinstance(outputs.pred_densepose, DensePoseEmbeddingPredictorOutput):
176
+ extractor = DensePoseOutputsExtractor()
177
+ result["pred_densepose"] = extractor(outputs)[0]
178
+ context["results"].append(result)
179
+
180
+ @classmethod
181
+ def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode):
182
+ context = {"results": [], "out_fname": args.output}
183
+ return context
184
+
185
+ @classmethod
186
+ def postexecute(cls: type, context: Dict[str, Any]):
187
+ out_fname = context["out_fname"]
188
+ out_dir = os.path.dirname(out_fname)
189
+ if len(out_dir) > 0 and not os.path.exists(out_dir):
190
+ os.makedirs(out_dir)
191
+ with open(out_fname, "wb") as hFile:
192
+ torch.save(context["results"], hFile)
193
+ logger.info(f"Output saved to {out_fname}")
194
+
195
+
196
+ @register_action
197
+ class ShowAction(InferenceAction):
198
+ """
199
+ Show action that visualizes selected entries on an image
200
+ """
201
+
202
+ COMMAND: ClassVar[str] = "show"
203
+ VISUALIZERS: ClassVar[Dict[str, object]] = {
204
+ "dp_contour": DensePoseResultsContourVisualizer,
205
+ "dp_segm": DensePoseResultsFineSegmentationVisualizer,
206
+ "dp_u": DensePoseResultsUVisualizer,
207
+ "dp_v": DensePoseResultsVVisualizer,
208
+ "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
209
+ "dp_cse_texture": DensePoseOutputsTextureVisualizer,
210
+ "dp_vertex": DensePoseOutputsVertexVisualizer,
211
+ "bbox": ScoredBoundingBoxVisualizer,
212
+ }
213
+
214
+ @classmethod
215
+ def add_parser(cls: type, subparsers: argparse._SubParsersAction):
216
+ parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
217
+ cls.add_arguments(parser)
218
+ parser.set_defaults(func=cls.execute)
219
+
220
+ @classmethod
221
+ def add_arguments(cls: type, parser: argparse.ArgumentParser):
222
+ super(ShowAction, cls).add_arguments(parser)
223
+ parser.add_argument(
224
+ "visualizations",
225
+ metavar="<visualizations>",
226
+ help="Comma separated list of visualizations, possible values: "
227
+ "[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
228
+ )
229
+ parser.add_argument(
230
+ "--min_score",
231
+ metavar="<score>",
232
+ default=0.8,
233
+ type=float,
234
+ help="Minimum detection score to visualize",
235
+ )
236
+ parser.add_argument(
237
+ "--nms_thresh", metavar="<threshold>", default=None, type=float, help="NMS threshold"
238
+ )
239
+ parser.add_argument(
240
+ "--texture_atlas",
241
+ metavar="<texture_atlas>",
242
+ default=None,
243
+ help="Texture atlas file (for IUV texture transfer)",
244
+ )
245
+ parser.add_argument(
246
+ "--texture_atlases_map",
247
+ metavar="<texture_atlases_map>",
248
+ default=None,
249
+ help="JSON string of a dict containing texture atlas files for each mesh",
250
+ )
251
+ parser.add_argument(
252
+ "--output",
253
+ metavar="<image_file>",
254
+ default="outputres.png",
255
+ help="File name to save output to",
256
+ )
257
+
258
+ @classmethod
259
+ def setup_config(
260
+ cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
261
+ ):
262
+ opts.append("MODEL.ROI_HEADS.SCORE_THRESH_TEST")
263
+ opts.append(str(args.min_score))
264
+ if args.nms_thresh is not None:
265
+ opts.append("MODEL.ROI_HEADS.NMS_THRESH_TEST")
266
+ opts.append(str(args.nms_thresh))
267
+ cfg = super(ShowAction, cls).setup_config(config_fpath, model_fpath, args, opts)
268
+ return cfg
269
+
270
+ @classmethod
271
+ def execute_on_outputs(
272
+ cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
273
+ ):
274
+ import cv2
275
+ import numpy as np
276
+ visualizer = context["visualizer"]
277
+ extractor = context["extractor"]
278
+ # image_fpath = entry["file_name"]
279
+ # logger.info(f"Processing {image_fpath}")
280
+ image = cv2.cvtColor(entry["image"], cv2.COLOR_BGR2GRAY)
281
+ image = np.tile(image[:, :, np.newaxis], [1, 1, 3])
282
+ data = extractor(outputs)
283
+ image_vis = visualizer.visualize(image, data)
284
+
285
+ return image_vis
286
+ entry_idx = context["entry_idx"] + 1
287
+ out_fname = './image-densepose/' + image_fpath.split('/')[-1]
288
+ out_dir = './image-densepose'
289
+ out_dir = os.path.dirname(out_fname)
290
+ if len(out_dir) > 0 and not os.path.exists(out_dir):
291
+ os.makedirs(out_dir)
292
+ cv2.imwrite(out_fname, image_vis)
293
+ logger.info(f"Output saved to {out_fname}")
294
+ context["entry_idx"] += 1
295
+
296
+ @classmethod
297
+ def postexecute(cls: type, context: Dict[str, Any]):
298
+ pass
299
+ # python ./apply_net.py show ./configs/densepose_rcnn_R_50_FPN_s1x.yaml https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl /home/alin0222/DressCode/upper_body/images dp_segm -v --opts MODEL.DEVICE cpu
300
+
301
+ @classmethod
302
+ def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
303
+ base, ext = os.path.splitext(fname_base)
304
+ return base + ".{0:04d}".format(entry_idx) + ext
305
+
306
+ @classmethod
307
+ def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode) -> Dict[str, Any]:
308
+ vis_specs = args.visualizations.split(",")
309
+ visualizers = []
310
+ extractors = []
311
+ for vis_spec in vis_specs:
312
+ texture_atlas = get_texture_atlas(args.texture_atlas)
313
+ texture_atlases_dict = get_texture_atlases(args.texture_atlases_map)
314
+ vis = cls.VISUALIZERS[vis_spec](
315
+ cfg=cfg,
316
+ texture_atlas=texture_atlas,
317
+ texture_atlases_dict=texture_atlases_dict,
318
+ )
319
+ visualizers.append(vis)
320
+ extractor = create_extractor(vis)
321
+ extractors.append(extractor)
322
+ visualizer = CompoundVisualizer(visualizers)
323
+ extractor = CompoundExtractor(extractors)
324
+ context = {
325
+ "extractor": extractor,
326
+ "visualizer": visualizer,
327
+ "out_fname": args.output,
328
+ "entry_idx": 0,
329
+ }
330
+ return context
331
+
332
+
333
+ def create_argument_parser() -> argparse.ArgumentParser:
334
+ parser = argparse.ArgumentParser(
335
+ description=DOC,
336
+ formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=120),
337
+ )
338
+ parser.set_defaults(func=lambda _: parser.print_help(sys.stdout))
339
+ subparsers = parser.add_subparsers(title="Actions")
340
+ for _, action in _ACTION_REGISTRY.items():
341
+ action.add_parser(subparsers)
342
+ return parser
343
+
344
+
345
+ def main():
346
+ parser = create_argument_parser()
347
+ args = parser.parse_args()
348
+ verbosity = getattr(args, "verbosity", None)
349
+ global logger
350
+ logger = setup_logger(name=LOGGER_NAME)
351
+ logger.setLevel(verbosity_to_level(verbosity))
352
+ args.func(args)
353
+
354
+
355
+ if __name__ == "__main__":
356
+ main()
357
+
358
+
359
+ # python ./apply_net.py show ./configs/densepose_rcnn_R_50_FPN_s1x.yaml https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl /home/alin0222/Dresscode/dresses/humanonly dp_segm -v --opts MODEL.DEVICE cuda
requirements.txt CHANGED
@@ -1,23 +1,23 @@
1
- transformers==4.36.2
2
- torch==2.0.1
3
- torchvision==0.15.2
4
- torchaudio==2.0.2
5
- numpy==1.24.4
6
- scipy==1.10.1
7
- scikit-image==0.21.0
8
- opencv-python==4.7.0.72
9
- pillow==9.4.0
10
- diffusers==0.25.0
11
- transformers==4.36.2
12
- accelerate==0.26.1
13
- matplotlib==3.7.4
14
- tqdm==4.64.1
15
- config==0.5.1
16
- einops==0.7.0
17
- onnxruntime==1.16.2
18
- basicsr
19
- av
20
- fvcore
21
- cloudpickle
22
- omegaconf
23
  pycocotools
 
1
+ transformers==4.36.2
2
+ torch==2.0.1
3
+ torchvision==0.15.2
4
+ torchaudio==2.0.2
5
+ numpy==1.24.4
6
+ scipy==1.10.1
7
+ scikit-image==0.21.0
8
+ opencv-python==4.7.0.72
9
+ pillow==9.4.0
10
+ diffusers==0.25.0
11
+ transformers==4.36.2
12
+ accelerate==0.26.1
13
+ matplotlib==3.7.4
14
+ tqdm==4.64.1
15
+ config==0.5.1
16
+ einops==0.7.0
17
+ onnxruntime==1.16.2
18
+ basicsr
19
+ av
20
+ fvcore
21
+ cloudpickle
22
+ omegaconf
23
  pycocotools
utils_mask.py CHANGED
@@ -1,167 +1,167 @@
1
- import numpy as np
2
- import cv2
3
- from PIL import Image, ImageDraw
4
-
5
- label_map = {
6
- "background": 0,
7
- "hat": 1,
8
- "hair": 2,
9
- "sunglasses": 3,
10
- "upper_clothes": 4,
11
- "skirt": 5,
12
- "pants": 6,
13
- "dress": 7,
14
- "belt": 8,
15
- "left_shoe": 9,
16
- "right_shoe": 10,
17
- "head": 11,
18
- "left_leg": 12,
19
- "right_leg": 13,
20
- "left_arm": 14,
21
- "right_arm": 15,
22
- "bag": 16,
23
- "scarf": 17,
24
- }
25
-
26
- def extend_arm_mask(wrist, elbow, scale):
27
- wrist = elbow + scale * (wrist - elbow)
28
- return wrist
29
-
30
- def hole_fill(img):
31
- img = np.pad(img[1:-1, 1:-1], pad_width = 1, mode = 'constant', constant_values=0)
32
- img_copy = img.copy()
33
- mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
34
-
35
- cv2.floodFill(img, mask, (0, 0), 255)
36
- img_inverse = cv2.bitwise_not(img)
37
- dst = cv2.bitwise_or(img_copy, img_inverse)
38
- return dst
39
-
40
- def refine_mask(mask):
41
- contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
42
- cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
43
- area = []
44
- for j in range(len(contours)):
45
- a_d = cv2.contourArea(contours[j], True)
46
- area.append(abs(a_d))
47
- refine_mask = np.zeros_like(mask).astype(np.uint8)
48
- if len(area) != 0:
49
- i = area.index(max(area))
50
- cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1)
51
-
52
- return refine_mask
53
-
54
- def get_mask_location(model_type, category, model_parse: Image.Image, keypoint: dict, width=384,height=512):
55
- im_parse = model_parse.resize((width, height), Image.NEAREST)
56
- parse_array = np.array(im_parse)
57
-
58
- if model_type == 'hd':
59
- arm_width = 60
60
- elif model_type == 'dc':
61
- arm_width = 45
62
- else:
63
- raise ValueError("model_type must be \'hd\' or \'dc\'!")
64
-
65
- parse_head = (parse_array == 1).astype(np.float32) + \
66
- (parse_array == 3).astype(np.float32) + \
67
- (parse_array == 11).astype(np.float32)
68
-
69
- parser_mask_fixed = (parse_array == label_map["left_shoe"]).astype(np.float32) + \
70
- (parse_array == label_map["right_shoe"]).astype(np.float32) + \
71
- (parse_array == label_map["hat"]).astype(np.float32) + \
72
- (parse_array == label_map["sunglasses"]).astype(np.float32) + \
73
- (parse_array == label_map["bag"]).astype(np.float32)
74
-
75
- parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)
76
-
77
- arms_left = (parse_array == 14).astype(np.float32)
78
- arms_right = (parse_array == 15).astype(np.float32)
79
-
80
- if category == 'dresses':
81
- parse_mask = (parse_array == 7).astype(np.float32) + \
82
- (parse_array == 4).astype(np.float32) + \
83
- (parse_array == 5).astype(np.float32) + \
84
- (parse_array == 6).astype(np.float32)
85
-
86
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
87
-
88
- elif category == 'upper_body':
89
- parse_mask = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)
90
- parser_mask_fixed_lower_cloth = (parse_array == label_map["skirt"]).astype(np.float32) + \
91
- (parse_array == label_map["pants"]).astype(np.float32)
92
- parser_mask_fixed += parser_mask_fixed_lower_cloth
93
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
94
- elif category == 'lower_body':
95
- parse_mask = (parse_array == 6).astype(np.float32) + \
96
- (parse_array == 12).astype(np.float32) + \
97
- (parse_array == 13).astype(np.float32) + \
98
- (parse_array == 5).astype(np.float32)
99
- parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
100
- (parse_array == 14).astype(np.float32) + \
101
- (parse_array == 15).astype(np.float32)
102
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
103
- else:
104
- raise NotImplementedError
105
-
106
- # Load pose points
107
- pose_data = keypoint["pose_keypoints_2d"]
108
- pose_data = np.array(pose_data)
109
- pose_data = pose_data.reshape((-1, 2))
110
-
111
- im_arms_left = Image.new('L', (width, height))
112
- im_arms_right = Image.new('L', (width, height))
113
- arms_draw_left = ImageDraw.Draw(im_arms_left)
114
- arms_draw_right = ImageDraw.Draw(im_arms_right)
115
- if category == 'dresses' or category == 'upper_body':
116
- shoulder_right = np.multiply(tuple(pose_data[2][:2]), height / 512.0)
117
- shoulder_left = np.multiply(tuple(pose_data[5][:2]), height / 512.0)
118
- elbow_right = np.multiply(tuple(pose_data[3][:2]), height / 512.0)
119
- elbow_left = np.multiply(tuple(pose_data[6][:2]), height / 512.0)
120
- wrist_right = np.multiply(tuple(pose_data[4][:2]), height / 512.0)
121
- wrist_left = np.multiply(tuple(pose_data[7][:2]), height / 512.0)
122
- ARM_LINE_WIDTH = int(arm_width / 512 * height)
123
- size_left = [shoulder_left[0] - ARM_LINE_WIDTH // 2, shoulder_left[1] - ARM_LINE_WIDTH // 2, shoulder_left[0] + ARM_LINE_WIDTH // 2, shoulder_left[1] + ARM_LINE_WIDTH // 2]
124
- size_right = [shoulder_right[0] - ARM_LINE_WIDTH // 2, shoulder_right[1] - ARM_LINE_WIDTH // 2, shoulder_right[0] + ARM_LINE_WIDTH // 2,
125
- shoulder_right[1] + ARM_LINE_WIDTH // 2]
126
-
127
-
128
- if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
129
- im_arms_right = arms_right
130
- else:
131
- wrist_right = extend_arm_mask(wrist_right, elbow_right, 1.2)
132
- arms_draw_right.line(np.concatenate((shoulder_right, elbow_right, wrist_right)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
133
- arms_draw_right.arc(size_right, 0, 360, 'white', ARM_LINE_WIDTH // 2)
134
-
135
- if wrist_left[0] <= 1. and wrist_left[1] <= 1.:
136
- im_arms_left = arms_left
137
- else:
138
- wrist_left = extend_arm_mask(wrist_left, elbow_left, 1.2)
139
- arms_draw_left.line(np.concatenate((wrist_left, elbow_left, shoulder_left)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
140
- arms_draw_left.arc(size_left, 0, 360, 'white', ARM_LINE_WIDTH // 2)
141
-
142
- hands_left = np.logical_and(np.logical_not(im_arms_left), arms_left)
143
- hands_right = np.logical_and(np.logical_not(im_arms_right), arms_right)
144
- parser_mask_fixed += hands_left + hands_right
145
-
146
- parser_mask_fixed = np.logical_or(parser_mask_fixed, parse_head)
147
- parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5)
148
- if category == 'dresses' or category == 'upper_body':
149
- neck_mask = (parse_array == 18).astype(np.float32)
150
- neck_mask = cv2.dilate(neck_mask, np.ones((5, 5), np.uint16), iterations=1)
151
- neck_mask = np.logical_and(neck_mask, np.logical_not(parse_head))
152
- parse_mask = np.logical_or(parse_mask, neck_mask)
153
- arm_mask = cv2.dilate(np.logical_or(im_arms_left, im_arms_right).astype('float32'), np.ones((5, 5), np.uint16), iterations=4)
154
- parse_mask += np.logical_or(parse_mask, arm_mask)
155
-
156
- parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
157
-
158
- parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
159
- inpaint_mask = 1 - parse_mask_total
160
- img = np.where(inpaint_mask, 255, 0)
161
- dst = hole_fill(img.astype(np.uint8))
162
- dst = refine_mask(dst)
163
- inpaint_mask = dst / 255 * 1
164
- mask = Image.fromarray(inpaint_mask.astype(np.uint8) * 255)
165
- mask_gray = Image.fromarray(inpaint_mask.astype(np.uint8) * 127)
166
-
167
- return mask, mask_gray
 
1
+ import numpy as np
2
+ import cv2
3
+ from PIL import Image, ImageDraw
4
+
5
+ label_map = {
6
+ "background": 0,
7
+ "hat": 1,
8
+ "hair": 2,
9
+ "sunglasses": 3,
10
+ "upper_clothes": 4,
11
+ "skirt": 5,
12
+ "pants": 6,
13
+ "dress": 7,
14
+ "belt": 8,
15
+ "left_shoe": 9,
16
+ "right_shoe": 10,
17
+ "head": 11,
18
+ "left_leg": 12,
19
+ "right_leg": 13,
20
+ "left_arm": 14,
21
+ "right_arm": 15,
22
+ "bag": 16,
23
+ "scarf": 17,
24
+ }
25
+
26
+ def extend_arm_mask(wrist, elbow, scale):
27
+ wrist = elbow + scale * (wrist - elbow)
28
+ return wrist
29
+
30
+ def hole_fill(img):
31
+ img = np.pad(img[1:-1, 1:-1], pad_width = 1, mode = 'constant', constant_values=0)
32
+ img_copy = img.copy()
33
+ mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
34
+
35
+ cv2.floodFill(img, mask, (0, 0), 255)
36
+ img_inverse = cv2.bitwise_not(img)
37
+ dst = cv2.bitwise_or(img_copy, img_inverse)
38
+ return dst
39
+
40
+ def refine_mask(mask):
41
+ contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
42
+ cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
43
+ area = []
44
+ for j in range(len(contours)):
45
+ a_d = cv2.contourArea(contours[j], True)
46
+ area.append(abs(a_d))
47
+ refine_mask = np.zeros_like(mask).astype(np.uint8)
48
+ if len(area) != 0:
49
+ i = area.index(max(area))
50
+ cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1)
51
+
52
+ return refine_mask
53
+
54
+ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint: dict, width=384,height=512):
55
+ im_parse = model_parse.resize((width, height), Image.NEAREST)
56
+ parse_array = np.array(im_parse)
57
+
58
+ if model_type == 'hd':
59
+ arm_width = 60
60
+ elif model_type == 'dc':
61
+ arm_width = 45
62
+ else:
63
+ raise ValueError("model_type must be \'hd\' or \'dc\'!")
64
+
65
+ parse_head = (parse_array == 1).astype(np.float32) + \
66
+ (parse_array == 3).astype(np.float32) + \
67
+ (parse_array == 11).astype(np.float32)
68
+
69
+ parser_mask_fixed = (parse_array == label_map["left_shoe"]).astype(np.float32) + \
70
+ (parse_array == label_map["right_shoe"]).astype(np.float32) + \
71
+ (parse_array == label_map["hat"]).astype(np.float32) + \
72
+ (parse_array == label_map["sunglasses"]).astype(np.float32) + \
73
+ (parse_array == label_map["bag"]).astype(np.float32)
74
+
75
+ parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)
76
+
77
+ arms_left = (parse_array == 14).astype(np.float32)
78
+ arms_right = (parse_array == 15).astype(np.float32)
79
+
80
+ if category == 'dresses':
81
+ parse_mask = (parse_array == 7).astype(np.float32) + \
82
+ (parse_array == 4).astype(np.float32) + \
83
+ (parse_array == 5).astype(np.float32) + \
84
+ (parse_array == 6).astype(np.float32)
85
+
86
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
87
+
88
+ elif category == 'upper_body':
89
+ parse_mask = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)
90
+ parser_mask_fixed_lower_cloth = (parse_array == label_map["skirt"]).astype(np.float32) + \
91
+ (parse_array == label_map["pants"]).astype(np.float32)
92
+ parser_mask_fixed += parser_mask_fixed_lower_cloth
93
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
94
+ elif category == 'lower_body':
95
+ parse_mask = (parse_array == 6).astype(np.float32) + \
96
+ (parse_array == 12).astype(np.float32) + \
97
+ (parse_array == 13).astype(np.float32) + \
98
+ (parse_array == 5).astype(np.float32)
99
+ parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
100
+ (parse_array == 14).astype(np.float32) + \
101
+ (parse_array == 15).astype(np.float32)
102
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
103
+ else:
104
+ raise NotImplementedError
105
+
106
+ # Load pose points
107
+ pose_data = keypoint["pose_keypoints_2d"]
108
+ pose_data = np.array(pose_data)
109
+ pose_data = pose_data.reshape((-1, 2))
110
+
111
+ im_arms_left = Image.new('L', (width, height))
112
+ im_arms_right = Image.new('L', (width, height))
113
+ arms_draw_left = ImageDraw.Draw(im_arms_left)
114
+ arms_draw_right = ImageDraw.Draw(im_arms_right)
115
+ if category == 'dresses' or category == 'upper_body':
116
+ shoulder_right = np.multiply(tuple(pose_data[2][:2]), height / 512.0)
117
+ shoulder_left = np.multiply(tuple(pose_data[5][:2]), height / 512.0)
118
+ elbow_right = np.multiply(tuple(pose_data[3][:2]), height / 512.0)
119
+ elbow_left = np.multiply(tuple(pose_data[6][:2]), height / 512.0)
120
+ wrist_right = np.multiply(tuple(pose_data[4][:2]), height / 512.0)
121
+ wrist_left = np.multiply(tuple(pose_data[7][:2]), height / 512.0)
122
+ ARM_LINE_WIDTH = int(arm_width / 512 * height)
123
+ size_left = [shoulder_left[0] - ARM_LINE_WIDTH // 2, shoulder_left[1] - ARM_LINE_WIDTH // 2, shoulder_left[0] + ARM_LINE_WIDTH // 2, shoulder_left[1] + ARM_LINE_WIDTH // 2]
124
+ size_right = [shoulder_right[0] - ARM_LINE_WIDTH // 2, shoulder_right[1] - ARM_LINE_WIDTH // 2, shoulder_right[0] + ARM_LINE_WIDTH // 2,
125
+ shoulder_right[1] + ARM_LINE_WIDTH // 2]
126
+
127
+
128
+ if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
129
+ im_arms_right = arms_right
130
+ else:
131
+ wrist_right = extend_arm_mask(wrist_right, elbow_right, 1.2)
132
+ arms_draw_right.line(np.concatenate((shoulder_right, elbow_right, wrist_right)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
133
+ arms_draw_right.arc(size_right, 0, 360, 'white', ARM_LINE_WIDTH // 2)
134
+
135
+ if wrist_left[0] <= 1. and wrist_left[1] <= 1.:
136
+ im_arms_left = arms_left
137
+ else:
138
+ wrist_left = extend_arm_mask(wrist_left, elbow_left, 1.2)
139
+ arms_draw_left.line(np.concatenate((wrist_left, elbow_left, shoulder_left)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
140
+ arms_draw_left.arc(size_left, 0, 360, 'white', ARM_LINE_WIDTH // 2)
141
+
142
+ hands_left = np.logical_and(np.logical_not(im_arms_left), arms_left)
143
+ hands_right = np.logical_and(np.logical_not(im_arms_right), arms_right)
144
+ parser_mask_fixed += hands_left + hands_right
145
+
146
+ parser_mask_fixed = np.logical_or(parser_mask_fixed, parse_head)
147
+ parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5)
148
+ if category == 'dresses' or category == 'upper_body':
149
+ neck_mask = (parse_array == 18).astype(np.float32)
150
+ neck_mask = cv2.dilate(neck_mask, np.ones((5, 5), np.uint16), iterations=1)
151
+ neck_mask = np.logical_and(neck_mask, np.logical_not(parse_head))
152
+ parse_mask = np.logical_or(parse_mask, neck_mask)
153
+ arm_mask = cv2.dilate(np.logical_or(im_arms_left, im_arms_right).astype('float32'), np.ones((5, 5), np.uint16), iterations=4)
154
+ parse_mask += np.logical_or(parse_mask, arm_mask)
155
+
156
+ parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
157
+
158
+ parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
159
+ inpaint_mask = 1 - parse_mask_total
160
+ img = np.where(inpaint_mask, 255, 0)
161
+ dst = hole_fill(img.astype(np.uint8))
162
+ dst = refine_mask(dst)
163
+ inpaint_mask = dst / 255 * 1
164
+ mask = Image.fromarray(inpaint_mask.astype(np.uint8) * 255)
165
+ mask_gray = Image.fromarray(inpaint_mask.astype(np.uint8) * 127)
166
+
167
+ return mask, mask_gray