ConsistentID / models /BiSeNet /face_dataset.py
JackAILab's picture
Upload 292 files
9669aec verified
raw
history blame
3.01 kB
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os.path as osp
import os
from PIL import Image
import numpy as np
import json
import cv2
from transform import *
class FaceMask(Dataset):
    """CelebAMask-HQ face-parsing dataset returning ``(image, label)`` pairs.

    Images are read from ``<rootpth>/CelebA-HQ-img`` and label maps from
    ``<rootpth>/mask``; a label map shares its image's file stem and has a
    ``.png`` extension.

    Args:
        rootpth: Dataset root directory.
        cropsize: Output size given to ``RandomCrop`` during training
            augmentation.
        mode: ``'train'``, ``'val'`` or ``'test'``. Only ``'train'`` applies
            augmentation. NOTE: every mode lists the same image directory,
            so no train/val/test split is performed here.
        *args, **kwargs: Forwarded to ``Dataset.__init__``.
    """

    def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
        super(FaceMask, self).__init__(*args, **kwargs)
        assert mode in ('train', 'val', 'test')
        self.mode = mode
        # Not read in this file; presumably the "ignore" label consumed by the
        # training loss - TODO confirm against the caller.
        self.ignore_lb = 255
        self.rootpth = rootpth
        # os.listdir order is arbitrary and filesystem-dependent; sort it so
        # the index -> file mapping is reproducible across machines and runs.
        self.imgs = sorted(os.listdir(osp.join(self.rootpth, 'CelebA-HQ-img')))
        # pre-processing: PIL image -> tensor, normalized per channel with the
        # standard ImageNet mean/std constants.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # Joint image/label augmentation from the project-local `transform`
        # module; it takes and returns a dict with 'im' and 'lb' keys and is
        # applied only when mode == 'train'.
        self.trans_train = Compose([
            ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5),
            HorizontalFlip(),
            RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
            RandomCrop(cropsize),
        ])

    @staticmethod
    def _mask_name(img_name):
        """Return the label-map file name for image file *img_name*.

        The label map has the same stem with a ``.png`` extension; using
        ``splitext`` handles extensions of any length (e.g. ``.jpeg``).
        """
        return osp.splitext(img_name)[0] + '.png'

    def __getitem__(self, idx):
        """Load sample *idx*.

        Returns:
            ``(img, label)`` where ``img`` is the output of ``self.to_tensor``
            and ``label`` is an ``int64`` numpy array with a leading singleton
            axis (``[np.newaxis, :]``).
        """
        impth = self.imgs[idx]
        img_path = osp.join(self.rootpth, 'CelebA-HQ-img', impth)
        lbl_path = osp.join(self.rootpth, 'mask', self._mask_name(impth))
        # Context managers release the file handles deterministically; resize()
        # and convert() return new images, so they stay valid after close.
        with Image.open(img_path) as src:
            # 512x512 matches the label maps produced by the __main__ script.
            img = src.resize((512, 512), Image.BILINEAR)
        with Image.open(lbl_path) as src:
            label = src.convert('P')
        if self.mode == 'train':
            im_lb = self.trans_train(dict(im=img, lb=label))
            img, label = im_lb['im'], im_lb['lb']
        img = self.to_tensor(img)
        label = np.array(label).astype(np.int64)[np.newaxis, :]
        return img, label

    def __len__(self):
        """Return the number of files listed under ``CelebA-HQ-img``."""
        return len(self.imgs)
if __name__ == "__main__":
    # One-off preprocessing: merge the per-attribute binary annotation PNGs of
    # CelebAMask-HQ into one 512x512 label map per image. Pixel value k is the
    # 1-based position of the attribute in `atts`; 0 means no attribute.
    face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
    mask_path = '/home/zll/data/CelebAMask-HQ/mask'
    # Attribute order defines label ids 1..18. Later entries overwrite earlier
    # ones where annotations overlap (e.g. 'hair' over 'skin').
    atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
            'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
    # cv2.imwrite does not create missing directories.
    os.makedirs(mask_path, exist_ok=True)
    counter = 0  # annotation files found
    total = 0    # annotation files looked up
    # Annotations are expected in sub-folders 0..14 holding 2000 images each.
    for i in range(15):
        for j in range(i * 2000, (i + 1) * 2000):
            # Explicit 8-bit label map (ids 0..18 fit) instead of relying on an
            # implicit float64 -> uint8 conversion inside cv2.imwrite.
            mask = np.zeros((512, 512), dtype=np.uint8)
            for label_id, att in enumerate(atts, 1):
                total += 1
                file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
                path = osp.join(face_sep_mask, str(i), file_name)
                if osp.exists(path):
                    counter += 1
                    with Image.open(path) as ann:
                        sep_mask = np.array(ann.convert('P'))
                    # NOTE(review): 225 is presumably the palette index that the
                    # white foreground maps to after convert('P'); confirm
                    # against the annotation files before changing it.
                    mask[sep_mask == 225] = label_id
            out_file = osp.join(mask_path, '{}.png'.format(j))
            # imwrite signals failure through its return value; don't ignore it.
            if not cv2.imwrite(out_file, mask):
                raise OSError('failed to write {}'.format(out_file))
            print(j)
    print(counter, total)