|
|
|
''' |
|
@File : visualizer.py |
|
@Time : 2022/04/05 11:39:33 |
|
@Author : Shilong Liu |
|
@Contact : liusl20@mail.tsinghua.edu.cn; slongliu86@gmail.com |
|
Modified from COCO evaluator |
|
''' |
|
|
|
import os, sys |
|
from textwrap import wrap |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
import datetime |
|
|
|
import matplotlib.pyplot as plt |
|
from matplotlib.collections import PatchCollection |
|
from matplotlib.patches import Polygon |
|
from pycocotools import mask as maskUtils |
|
from matplotlib import transforms |
|
|
|
def renorm(img: torch.FloatTensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) \
        -> torch.FloatTensor:
    """Undo per-channel normalization, i.e. return ``img * std + mean``.

    Args:
        img: Channel-first image tensor of shape ``(3, H, W)`` or
            ``(B, 3, H, W)``.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations the image was divided by.

    Returns:
        A tensor with the same shape as ``img``, located on ``img``'s device.

    Raises:
        AssertionError: If ``img`` is neither 3-D nor 4-D, or if its channel
            axis does not have size 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()

    # The channel axis is 0 for a single image and 1 for a batch.
    channel_axis = 0 if img.dim() == 3 else 1
    assert img.size(channel_axis) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_axis, img.size(channel_axis), str(img.size()))

    # Build the statistics on img's device so that CUDA inputs do not fail
    # with a CPU/GPU mismatch. The default float dtype is what
    # torch.Tensor(list) produced before, so type promotion is unchanged.
    stats_kwargs = dict(dtype=torch.get_default_dtype(), device=img.device)
    mean_t = torch.as_tensor(mean, **stats_kwargs).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, **stats_kwargs).reshape(-1, 1, 1)

    # (C, 1, 1) broadcasts against both (C, H, W) and (B, C, H, W), so no
    # permutation to channel-last and back is needed.
    return img * std_t + mean_t
|
|
|
class ColorMap():
    """Convert a single-channel attention map into a solid-color RGBA overlay."""

    def __init__(self, basergb=[255,255,0]):
        # Keep the base color as an array so it can be broadcast over the map.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """
        Build an overlay whose color channels are the base color everywhere
        and whose last channel is the attention map itself.

        attnmap: 2-D np.uint8 array of shape (h, w).
        return: np.uint8 array of shape (h, w, len(basergb) + 1).
        """
        assert attnmap.dtype == np.uint8
        height, width = attnmap.shape

        # Every pixel gets the same base color; only the last channel varies.
        color_planes = np.broadcast_to(self.basergb, (height, width) + self.basergb.shape)
        alpha_plane = attnmap[..., np.newaxis]

        overlay = np.concatenate((color_planes, alpha_plane), axis=-1)
        return overlay.astype(np.uint8)
|
|
|
|
|
class COCOVisualizer():
    """Draw a normalized image together with its box annotations using matplotlib."""

    def __init__(self) -> None:
        pass

    def visualize(self, img, tgt, caption=None, dpi=120, savedir=None, show_in_console=True):
        """
        Render ``img`` with the annotations in ``tgt``; optionally save and/or show it.

        img: tensor(3, H, W), normalized; it is de-normalized with ``renorm``.
        tgt: dict consumed by ``addtgt``.
            must have items: 'boxes', 'size'
            'image_id' is additionally required when ``savedir`` is given
            (it is part of the output file name).
        caption: optional prefix for the saved file name.
        dpi: resolution of the created figure.
        savedir: directory to write a PNG into; nothing is saved when None.
        show_in_console: call ``plt.show()`` when True.
        """
        fig = plt.figure(dpi=dpi)
        try:
            # NOTE: this changes the global matplotlib font size and the
            # setting persists after the call (behavior kept from the original).
            plt.rcParams['font.size'] = '5'
            ax = plt.gca()

            # matplotlib needs host memory without autograd history.
            img = renorm(img.detach().cpu()).permute(1, 2, 0)
            ax.imshow(img.numpy())

            self.addtgt(tgt)

            # Save BEFORE showing: with blocking or inline backends the figure
            # is gone once plt.show() returns, and a later save is blank.
            if savedir is not None:
                timestamp = str(datetime.datetime.now()).replace(' ', '-')
                if caption is None:
                    savename = '{}/{}-{}.png'.format(savedir, int(tgt['image_id']), timestamp)
                else:
                    savename = '{}/{}-{}-{}.png'.format(savedir, caption, int(tgt['image_id']), timestamp)
                print("savename: {}".format(savename))
                os.makedirs(os.path.dirname(savename), exist_ok=True)
                fig.savefig(savename)

            if show_in_console:
                plt.show()
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    def addtgt(self, tgt):
        """
        Draw the boxes of ``tgt`` (plus optional labels and title) on the current axes.

        - tgt: dict. args:
            - boxes: num_boxes, 4. (cx, cy, w, h) with values in [0,1];
              the code converts the center to the top-left corner.
            - size: (H, W), used to scale the normalized boxes to pixels.
            - box_label: optional, num_boxes. Text drawn at each box corner.
            - caption: optional, used as the axes title.
        """
        assert 'boxes' in tgt
        ax = plt.gca()
        H, W = torch.as_tensor(tgt['size']).tolist()

        # Scale every box to pixels at once, then move (cx, cy) to the
        # top-left corner so that each row becomes (x, y, w, h).
        scale = torch.tensor([W, H, W, H], dtype=torch.get_default_dtype())
        pixel_boxes = tgt['boxes'].detach().cpu().reshape(-1, 4) * scale
        pixel_boxes[:, :2] -= pixel_boxes[:, 2:] / 2
        boxes = pixel_boxes.tolist()
        numbox = len(boxes)

        polygons = [
            Polygon(np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]))
            for x, y, w, h in boxes
        ]
        # One random, fairly bright RGB color (each channel in [0.4, 1.0)) per box.
        color = (np.random.random((numbox, 3)) * 0.6 + 0.4).tolist()

        # First pass: faint fill. Second pass: solid outline in the same color.
        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
        ax.add_collection(p)

        if 'box_label' in tgt:
            assert len(tgt['box_label']) == numbox, \
                f"number of box_label ({len(tgt['box_label'])}) != number of boxes ({numbox})"
            for (bbox_x, bbox_y, _, _), box_color, bl in zip(boxes, color, tgt['box_label']):
                ax.text(bbox_x, bbox_y, str(bl), color='black',
                        bbox={'facecolor': box_color, 'alpha': 0.6, 'pad': 1})

        if 'caption' in tgt:
            ax.set_title(tgt['caption'], wrap=True)
|
|
|
|
|
|