# Spaces: Running
import os | |
import cv2 | |
import argparse | |
import numpy as np | |
import torch | |
import torchvision | |
from torchvision import datasets, transforms | |
from torch.autograd import Variable | |
from network_v0.model import PointModel | |
from datasets.hp_loader import PatchesDataset | |
from torch.utils.data import DataLoader | |
from evaluation.evaluate import evaluate_keypoint_net | |
def _make_loader(test_dir, output_shape):
    """Build a sequential DataLoader (batch size 1) over the test set.

    Args:
        test_dir: Root directory handed to ``PatchesDataset`` as ``root_dir``.
        output_shape: Resolution tuple forwarded unchanged to ``PatchesDataset``
            (e.g. ``(320, 240)``).

    Returns:
        A ``DataLoader`` that yields one sample at a time in dataset order.
    """
    dataset = PatchesDataset(root_dir=test_dir, use_color=True,
                             output_shape=output_shape, type='all')
    return DataLoader(dataset,
                      batch_size=1,
                      pin_memory=False,
                      shuffle=False,  # deterministic order for reproducible metrics
                      num_workers=4,
                      worker_init_fn=None,
                      sampler=None)


def _print_metrics(rep, loc, c1, c3, c5, mscore):
    """Print the six metrics returned by ``evaluate_keypoint_net``, then a blank gap."""
    print('Repeatability: {0:.3f}'.format(rep))
    print('Localization Error: {0:.3f}'.format(loc))
    print('H-1 Accuracy: {:.3f}'.format(c1))
    print('H-3 Accuracy: {:.3f}'.format(c3))
    print('H-5 Accuracy: {:.3f}'.format(c5))
    print('Matching Score: {:.3f}'.format(mscore))
    print('\n')


def main(argv=None):
    """Evaluate the PointModel checkpoint on a test set at two resolutions.

    Runs ``evaluate_keypoint_net`` at 320x240 keeping 300 points and at
    640x480 keeping 1000 points. For each run it prints repeatability,
    localization error, H-1/H-3/H-5 accuracy and matching score.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            parses ``sys.argv``, matching the original script behaviour.
    """
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--device', default=0, type=int, help='which gpu to run on.')
    parser.add_argument('--test_dir', required=True, type=str, help='Test data path.')
    parser.add_argument('--checkpoint', default='./checkpoints/PointModel_v0.pth', type=str,
                        help='Path to the model checkpoint.')
    opt = parser.parse_args(argv)

    torch.manual_seed(0)
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        torch.cuda.set_device(opt.device)

    # (output_shape, top_k) pairs, evaluated in this order.
    settings = (((320, 240), 300), ((640, 480), 1000))
    # Build every loader up front, as the original did, before touching the model.
    runs = [(shape, top_k, _make_loader(opt.test_dir, shape)) for shape, top_k in settings]

    # Load model
    model = PointModel(is_test=True)
    # map_location='cpu' lets a checkpoint holding CUDA tensors load on a
    # CPU-only machine; the model is moved to the GPU below when one exists.
    ckpt = torch.load(opt.checkpoint, map_location='cpu')
    model.load_state_dict(ckpt['model_state'])
    model = model.eval()
    if use_gpu:
        model = model.cuda()

    for shape, top_k, loader in runs:
        print('Evaluating in {}x{}, {} points'.format(shape[0], shape[1], top_k))
        rep, loc, c1, c3, c5, mscore = evaluate_keypoint_net(
            loader,
            model,
            output_shape=shape,
            top_k=top_k)
        _print_metrics(rep, loc, c1, c3, c5, mscore)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()