# Leffa/SCHP/__init__.py
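"""Inference wrapper for SCHP (Self-Correction for Human Parsing).

Loads a pretrained ResNet-101 parsing network for the LIP, ATR, or
Pascal-Person-Part label set (inferred from the checkpoint path) and converts
input images into palettized segmentation masks returned as PIL images.
"""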
from collections import OrderedDict
import cv2
import numpy as np
import torch
from PIL import Image
from SCHP import networks
from SCHP.utils.transforms import get_affine_transform, transform_logits
from torchvision import transforms
def get_palette(num_cls):
"""Returns the color map for visualizing the segmentation mask.
Args:
num_cls: Number of classes
Returns:
The color map
"""
n = num_cls
palette = [0] * (n * 3)
for j in range(0, n):
lab = j
palette[j * 3 + 0] = 0
palette[j * 3 + 1] = 0
palette[j * 3 + 2] = 0
i = 0
while lab:
palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
i += 1
lab >>= 3
return palette
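
# Sanity check for get_palette: the VOC-style bit encoding above yields
# (0, 0, 0) for class 0, (128, 0, 0) for class 1, (0, 128, 0) for class 2,
# and (128, 128, 0) for class 3, i.e. get_palette(4)[:12] ==
# [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0].
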
dataset_settings = {
"lip": {
"input_size": [473, 473],
"num_classes": 20,
"label": [
"Background",
"Hat",
"Hair",
"Glove",
"Sunglasses",
"Upper-clothes",
"Dress",
"Coat",
"Socks",
"Pants",
"Jumpsuits",
"Scarf",
"Skirt",
"Face",
"Left-arm",
"Right-arm",
"Left-leg",
"Right-leg",
"Left-shoe",
"Right-shoe",
],
},
"atr": {
"input_size": [512, 512],
"num_classes": 18,
"label": [
"Background",
"Hat",
"Hair",
"Sunglasses",
"Upper-clothes",
"Skirt",
"Pants",
"Dress",
"Belt",
"Left-shoe",
"Right-shoe",
"Face",
"Left-leg",
"Right-leg",
"Left-arm",
"Right-arm",
"Bag",
"Scarf",
],
},
"pascal": {
"input_size": [512, 512],
"num_classes": 7,
"label": [
"Background",
"Head",
"Torso",
"Upper Arms",
"Lower Arms",
"Upper Legs",
"Lower Legs",
],
},
}
class SCHP:
def __init__(self, ckpt_path, device):
dataset_type = None
if "lip" in ckpt_path:
dataset_type = "lip"
elif "atr" in ckpt_path:
dataset_type = "atr"
elif "pascal" in ckpt_path:
dataset_type = "pascal"
        if dataset_type is None:
            raise ValueError(
                "Could not infer dataset type ('lip', 'atr', or 'pascal') from checkpoint path"
            )
self.device = device
self.num_classes = dataset_settings[dataset_type]["num_classes"]
self.input_size = dataset_settings[dataset_type]["input_size"]
self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
self.palette = get_palette(self.num_classes)
self.label = dataset_settings[dataset_type]["label"]
self.model = networks.init_model(
"resnet101", num_classes=self.num_classes, pretrained=None
).to(device)
self.load_ckpt(ckpt_path)
self.model.eval()
self.transform = transforms.Compose(
[
transforms.ToTensor(),
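                # ImageNet mean/std in BGR channel order, matching the BGR
                # images produced in preprocess().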
transforms.Normalize(
mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]
),
]
)
self.upsample = torch.nn.Upsample(
size=self.input_size, mode="bilinear", align_corners=True
)
def load_ckpt(self, ckpt_path):
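        # A few decoder / fusion sub-layer keys are indexed one position lower
        # in the checkpoint than in the current network definition, so they are
        # remapped before loading (strict=False tolerates remaining mismatches).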
rename_map = {
"decoder.conv3.2.weight": "decoder.conv3.3.weight",
"decoder.conv3.3.weight": "decoder.conv3.4.weight",
"decoder.conv3.3.bias": "decoder.conv3.4.bias",
"decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
"decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
"fushion.3.weight": "fushion.4.weight",
"fushion.3.bias": "fushion.4.bias",
}
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
            # Strip the DataParallel "module." prefix if it is present.
            name = k[7:] if k.startswith("module.") else k
new_state_dict[name] = v
new_state_dict_ = OrderedDict()
for k, v in list(new_state_dict.items()):
if k in rename_map:
new_state_dict_[rename_map[k]] = v
else:
new_state_dict_[k] = v
self.model.load_state_dict(new_state_dict_, strict=False)
def _box2cs(self, box):
x, y, w, h = box[:4]
return self._xywh2cs(x, y, w, h)
def _xywh2cs(self, x, y, w, h):
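        # Convert an (x, y, w, h) box into a center point and a scale, padding
        # the shorter side so the box matches the network input aspect ratio.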
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > self.aspect_ratio * h:
h = w * 1.0 / self.aspect_ratio
elif w < self.aspect_ratio * h:
w = h * self.aspect_ratio
scale = np.array([w, h], dtype=np.float32)
return center, scale
    def preprocess(self, image):
        if isinstance(image, str):
            img = cv2.imread(image, cv2.IMREAD_COLOR)
        elif isinstance(image, Image.Image):
            # Convert to the BGR array layout produced by cv2.imread, which the
            # BGR-ordered normalization statistics expect.
            img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            raise TypeError("image must be a file path or a PIL.Image.Image")
h, w, _ = img.shape
# Get person center and scale
person_center, s = self._box2cs([0, 0, w - 1, h - 1])
r = 0
trans = get_affine_transform(person_center, s, r, self.input_size)
input = cv2.warpAffine(
img,
trans,
(int(self.input_size[1]), int(self.input_size[0])),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0, 0, 0),
)
input = self.transform(input).to(self.device).unsqueeze(0)
meta = {
"center": person_center,
"height": h,
"width": w,
"scale": s,
"rotation": r,
}
return input, meta
def __call__(self, image_or_path):
if isinstance(image_or_path, list):
image_list = []
meta_list = []
for image in image_or_path:
image, meta = self.preprocess(image)
image_list.append(image)
meta_list.append(meta)
image = torch.cat(image_list, dim=0)
else:
image, meta = self.preprocess(image_or_path)
meta_list = [meta]
        with torch.no_grad():
            output = self.model(image)
# upsample_outputs = self.upsample(output[0][-1])
upsample_outputs = self.upsample(output)
upsample_outputs = upsample_outputs.permute(0, 2, 3, 1) # BCHW -> BHWC
output_img_list = []
for upsample_output, meta in zip(upsample_outputs, meta_list):
c, s, w, h = meta["center"], meta["scale"], meta["width"], meta["height"]
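            # Undo the affine crop: project the logits back onto the original
            # image geometry before taking the per-pixel argmax.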
logits_result = transform_logits(
upsample_output.data.cpu().numpy(),
c,
s,
w,
h,
input_size=self.input_size,
)
parsing_result = np.argmax(logits_result, axis=2)
output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
output_img.putpalette(self.palette)
output_img_list.append(output_img)
return output_img_list[0] if len(output_img_list) == 1 else output_img_list
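

# Example usage sketch. The checkpoint and image paths below are placeholders;
# the checkpoint filename must contain "lip", "atr", or "pascal" so __init__
# can infer the matching label set and input size.
if __name__ == "__main__":
    parser = SCHP(ckpt_path="checkpoints/schp-atr.pth", device="cuda")
    mask = parser(Image.open("person.jpg"))  # one input -> one palettized PIL image
    mask.save("parsing_atr.png")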