|
import pdb |
|
import os |
|
import sys |
|
import tqdm |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from PIL import Image |
|
from matplotlib import pyplot as pl; pl.ion() |
|
from scipy.ndimage import uniform_filter |
|
def smooth(arr):
    """Return ``arr`` smoothed by a size-3 uniform (box) filter."""
    return uniform_filter(arr, size=3)
|
|
|
def transparent(img, alpha, cmap, **kw):
    """Colour-map ``img`` and give the result a constant opacity.

    Args:
        img: array of scalar values; the result is indexed with three axes,
            so ``cmap`` must return a 3-D array for it.
        alpha: value written into the last channel of the colour-mapped image.
        cmap: callable applied to the normalised values (called as
            ``cmap(normalised)``), e.g. a matplotlib colormap.
        **kw: forwarded to ``matplotlib.colors.Normalize`` (e.g. ``vmin``,
            ``vmax``); ``clip=True`` is always set.

    Returns:
        The array returned by ``cmap`` with its last channel set to ``alpha``.
    """
    from matplotlib.colors import Normalize

    normalizer = Normalize(clip=True, **kw)
    rgba = cmap(normalizer(img))
    # Overwrite the last channel of the colour-mapped image with the opacity.
    rgba[:, :, -1] = alpha
    return rgba
|
|
|
from tools import common |
|
from tools.dataloader import norm_RGB |
|
from nets.patchnet import * |
|
from extract import NonMaxSuppression |
|
|
|
|
|
def border_crop(border):
    """Return a 2-D index that trims ``border`` pixels from every side.

    Args:
        border: margin width in pixels; ``0`` keeps the whole array.

    Returns:
        A pair of identical slices, usable as ``arr[border_crop(border)]``.
    """
    # The stop must be ``None`` for a zero border: ``-0`` is falsy, and the
    # previous fallback of ``1`` reduced the view to a single pixel.
    margin = slice(border, -border or None)
    return (margin, margin)


if __name__ == '__main__':
    # Command-line tool: run the network on one image, detect keypoints and
    # save a three-panel figure (keypoints / repeatability / reliability).
    import argparse

    parser = argparse.ArgumentParser("Visualize the patch detector and descriptor")

    parser.add_argument("--img", type=str, default="imgs/brooklyn.png")
    parser.add_argument("--resize", type=int, default=512)
    parser.add_argument("--out", type=str, default="viz.png")

    parser.add_argument("--checkpoint", type=str, required=True, help='network path')
    parser.add_argument("--net", type=str, default="", help='network command')

    parser.add_argument("--max-kpts", type=int, default=200)
    parser.add_argument("--reliability-thr", type=float, default=0.8)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument("--border", type=int, default=20, help='rm keypoints close to border')

    parser.add_argument("--gpu", type=int, nargs='+', required=True, help='-1 for CPU')
    parser.add_argument("--dbg", type=str, nargs='+', default=(), help='debug options')

    args = parser.parse_args()
    args.dbg = set(args.dbg)

    iscuda = common.torch_set_gpu(args.gpu)
    device = torch.device('cuda' if iscuda else 'cpu')

    # The identity map_location returns each storage as deserialized, so the
    # checkpoint can be read on a machine without the GPU it was saved from.
    checkpoint = torch.load(args.checkpoint, lambda a, b: a)
    args.net = args.net or checkpoint['net']
    print("\n>> Creating net = " + args.net)
    # SECURITY NOTE: eval() runs whatever expression comes from --net or from
    # the checkpoint file. Only use checkpoints and commands you trust.
    net = eval(args.net)
    # Drop the 'module.' prefix from the saved parameter names before loading
    # (presumably left by a DataParallel wrapper -- confirm against training).
    net.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()})
    if iscuda:
        net = net.cuda()
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    img = Image.open(args.img).convert('RGB')
    if args.resize:
        # thumbnail() shrinks in place and preserves the aspect ratio.
        img.thumbnail((args.resize, args.resize))
    img = np.asarray(img)

    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr,
        rep_thr=args.repeatability_thr)

    with torch.no_grad():
        print(">> computing features...")
        res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)])
        rela = res.get('reliability')
        repe = res.get('repeatability')
        # Swap the detector's two coordinate rows so that column 0 indexes
        # image columns (x) and column 1 indexes image rows (y).
        kpts = detector(**res).T[:, [1, 0]]
        # Keep the max_kpts keypoints with the highest repeatability score.
        kpts = kpts[repe[0][0, 0][kpts[:, 1], kpts[:, 0]].argsort()[-args.max_kpts:]]

    fig = pl.figure("viz")
    kw = dict(cmap=pl.cm.RdYlGn, vmax=1)
    crop = border_crop(args.border)

    # Keypoint coordinates expressed in the cropped image frame.
    x, y = kpts[:, 0:2].cpu().numpy().T - args.border
    rela = rela[0][0, 0].cpu().numpy()

    if 'reliability' in args.dbg:
        # Debug layout: image | keypoints only | raw reliability map.
        pl.subplot(131)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())

        pl.subplot(132)
        # alpha=0 hides the image but keeps its extent for the keypoints.
        pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0)
        pl.xticks(()); pl.yticks(())
        pl.plot(x, y, '+', c=(0, 1, 0), ms=10, scalex=0, scaley=0)

        pl.subplot(133)
        pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9)
        pl.xticks(()); pl.yticks(())

    else:
        # Default layout: keypoints | repeatability overlay | reliability overlay.
        pl.subplot(131)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())
        pl.plot(x, y, '+', c=(0, 1, 0), ms=10, scalex=0, scaley=0)

        pl.subplot(132)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())
        c = repe[0][0, 0].cpu().numpy()
        pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw))

        pl.subplot(133)
        pl.imshow(img[crop], cmap=pl.cm.gray)
        pl.xticks(()); pl.yticks(())
        pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw))

    pl.gcf().set_size_inches(9, 2.73)
    pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1)
    pl.savefig(args.out)
    # NOTE(review): drops into the debugger after saving; this looks like a
    # deliberate way to keep the interactive figure open for inspection, but
    # it blocks unattended runs -- confirm before removing.
    pdb.set_trace()
|
|
|
|