|
import os |
|
import cv2 |
|
import argparse |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
|
|
from torchvision import datasets, transforms |
|
from torch.autograd import Variable |
|
from network_v0.model import PointModel |
|
from datasets.hp_loader import PatchesDataset |
|
from torch.utils.data import DataLoader |
|
from evaluation.evaluate import evaluate_keypoint_net |
|
|
|
|
|
def _build_loader(root_dir, output_shape):
    """Create a sequential, single-sample DataLoader over the test patches.

    Args:
        root_dir: Path to the test data, forwarded to ``PatchesDataset``.
        output_shape: Two-element shape forwarded to ``PatchesDataset``.

    Returns:
        A ``DataLoader`` with batch size 1, no shuffling and 4 workers.
    """
    dataset = PatchesDataset(root_dir=root_dir, use_color=True,
                             output_shape=output_shape, type='all')
    return DataLoader(dataset,
                      batch_size=1,
                      pin_memory=False,
                      shuffle=False,
                      num_workers=4)


def _evaluate_and_report(data_loader, model, output_shape, top_k):
    """Run ``evaluate_keypoint_net`` once and print the six returned metrics.

    Args:
        data_loader: Loader produced for ``output_shape``.
        model: Network in eval mode.
        output_shape: Two-element shape forwarded to the evaluator.
        top_k: Number of points forwarded to the evaluator.
    """
    print('Evaluating in {}x{}, {} points'.format(
        output_shape[0], output_shape[1], top_k))
    metrics = evaluate_keypoint_net(
        data_loader,
        model,
        output_shape=output_shape,
        top_k=top_k)

    # Labels follow the order in which the evaluator returns its values.
    labels = ('Repeatability', 'Localization Error', 'H-1 Accuracy',
              'H-3 Accuracy', 'H-5 Accuracy', 'Matching Score')
    for label, value in zip(labels, metrics):
        print('{}: {:.3f}'.format(label, value))
    print('\n')


def main():
    """Evaluate the PointModel checkpoint on the test set at two settings.

    Command-line arguments:
        --device: Index of the GPU to use when CUDA is available (default 0).
        --test_dir: Path to the test data (required).

    The checkpoint is read from ``./checkpoints/PointModel_v0.pth`` and
    evaluated with (320, 240) / 300 points and (640, 480) / 1000 points;
    the metrics of each run are printed to stdout.
    """
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--device', default=0, type=int, help='which gpu to run on.')
    parser.add_argument('--test_dir', required=True, type=str, help='Test data path.')
    opt = parser.parse_args()

    torch.manual_seed(0)
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        torch.cuda.set_device(opt.device)

    # (output_shape, top_k) for each evaluation run, in reporting order.
    configs = (
        ((320, 240), 300),
        ((640, 480), 1000),
    )
    # Build every loader up front so a bad --test_dir fails before the
    # checkpoint is loaded, as in the original flow.
    loaders = [_build_loader(opt.test_dir, shape) for shape, _ in configs]

    model = PointModel(is_test=True)
    # Map tensors to CPU while deserializing: a checkpoint saved from a GPU
    # would otherwise fail to load on a CUDA-less host (or land on the GPU it
    # was saved from instead of the one selected with --device).
    ckpt = torch.load('./checkpoints/PointModel_v0.pth', map_location='cpu')
    model.load_state_dict(ckpt['model_state'])
    model = model.eval()
    if use_gpu:
        model = model.cuda()

    for loader, (shape, top_k) in zip(loaders, configs):
        _evaluate_and_report(loader, model, shape, top_k)
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|