JackAILab's picture
Upload 292 files
9669aec verified
raw
history blame
3.1 kB
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from logger import setup_logger
from model import BiSeNet
from face_dataset import FaceMask
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os
import os.path as osp
import logging
import time
import numpy as np
from tqdm import tqdm
import math
from PIL import Image
import torchvision.transforms as transforms
import cv2
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
    """Blend a colour-coded parsing map over an image and optionally save it.

    Args:
        im: Input image, converted with ``np.array`` and treated as RGB
            (e.g. a PIL image).
        parsing_anno: 2-D integer label map (NumPy array). Label 0 is drawn
            white; label ``k > 0`` is drawn with ``part_colors[k]``.
        stride: Scale factor applied to ``parsing_anno`` with
            nearest-neighbour interpolation before blending. The scaled map
            has to match ``im`` in height and width for the blend to work.
        save_im: When true, write the blended image to ``save_path``.
        save_path: Destination file. Missing parent directories are created.

    Returns:
        The blended image as a ``uint8`` array in BGR channel order.

    Raises:
        IndexError: If a label is not a valid index into the palette.
    """
    # Palette with 24 entries; entry 0 is never used because label 0 stays white.
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]
    palette = np.array(part_colors, dtype=np.uint8)

    vis_im = np.array(im).astype(np.uint8)

    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride,
                                  interpolation=cv2.INTER_NEAREST)

    # Start from an all-white canvas and paint every non-zero label in one
    # palette lookup instead of looping over classes in Python.
    vis_parsing_anno_color = np.full(vis_parsing_anno.shape[:2] + (3,), 255, dtype=np.uint8)
    foreground = vis_parsing_anno > 0
    vis_parsing_anno_color[foreground] = palette[vis_parsing_anno[foreground]]

    vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4,
                             vis_parsing_anno_color, 0.6, 0)

    if save_im:
        # cv2.imwrite does not create directories and only signals failure via
        # its return value, so make sure the target directory exists first.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if not cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]):
            logging.getLogger(__name__).warning('Failed to write %s', save_path)

    return vis_im
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    """Run the parsing network on every image in a directory and save overlays.

    A 19-class ``BiSeNet`` is moved to the GPU (CUDA is required) and its
    weights are loaded from ``res/cp/<cp>``. Each regular file in ``dspth`` is
    opened as an image, converted to RGB, resized to 512x512, normalised and
    passed through the network. The per-pixel argmax of the first network
    output is rendered with ``vis_parsing_maps`` and written to ``respth``
    under the input's file name.

    Args:
        respth: Output directory; created if it does not exist.
        dspth: Directory containing the input images.
        cp: Checkpoint file name inside ``res/cp``.
    """
    os.makedirs(respth, exist_ok=True)

    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    save_pth = osp.join('res/cp', cp)
    # Deserialize on the CPU so the checkpoint loads regardless of which
    # device it was saved from; load_state_dict copies into the GPU parameters.
    net.load_state_dict(torch.load(save_pth, map_location='cpu'))
    net.eval()

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    with torch.no_grad():
        # Sorted for a deterministic processing order.
        for image_path in sorted(os.listdir(dspth)):
            src_path = osp.join(dspth, image_path)
            if not osp.isfile(src_path):
                # Sub-directories and other non-files cannot be opened as images.
                continue

            # Force three channels: the normalisation above and the RGB->BGR
            # conversion in vis_parsing_maps both assume an RGB image.
            with Image.open(src_path) as img:
                image = img.convert('RGB').resize((512, 512), Image.BILINEAR)

            img = to_tensor(image)
            img = torch.unsqueeze(img, 0)
            img = img.cuda()
            out = net(img)[0]
            parsing = out.squeeze(0).cpu().numpy().argmax(0)

            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=osp.join(respth, image_path))
if __name__ == "__main__":
    # Script entry point: set up logging under ./res, then run the
    # evaluation with its default paths and checkpoint name.
    setup_logger('./res')
    evaluate()