|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import transforms
|
|
from models.resnet import resnet18
|
|
from models.openmax import OpenMax
|
|
from models.metamax import MetaMax
|
|
from train import GameDataset
|
|
from utils.data_stats import load_dataset_stats
|
|
from utils.eval_utils import evaluate_known_classes, evaluate_openmax, evaluate_metamax
|
|
import os
|
|
from pprint import pprint
|
|
|
|
def test_models(
    *,
    data_dir='jk_zfls/round0_eval',
    model_path='models/best_model_99.92_02.pth',
    openmax_path='models/best_openmax_94.71_01.pth',
    batch_size=400,
    num_workers=4,
):
    """Evaluate the trained ResNet-18 alone and combined with OpenMax.

    Builds the evaluation ``GameDataset`` from ``data_dir`` and restores the
    ResNet weights from ``model_path``. Loads the OpenMax object stored at
    ``openmax_path``. Then it prints the closed-set accuracy (and any errors
    reported by ``evaluate_known_classes``) and runs ``evaluate_openmax``.

    All parameters are keyword-only and default to the values this script
    originally hard-coded, so ``test_models()`` behaves as before.

    Args:
        data_dir: Directory passed to ``GameDataset`` for the evaluation split.
        model_path: Checkpoint file containing a ``'model_state_dict'`` entry
            for the ResNet-18 classifier.
        openmax_path: File loaded with ``torch.load`` and handed to
            ``evaluate_openmax``.
        batch_size: Batch size for the evaluation ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        None. Results are printed. If the OpenMax file cannot be loaded, the
        error is printed and the function returns before any evaluation.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Normalize inputs with the dataset statistics provided by the project.
    mean, std = load_dataset_stats()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # NOTE(review): the dataset is built with 21 labels while the classifier
    # has 20 outputs. This is presumably 20 known classes plus one "unknown"
    # label for open-set evaluation; confirm against GameDataset.
    test_dataset = GameDataset(data_dir, num_labels=21, transform=transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only helps host->GPU copies; skip it on CPU.
        pin_memory=device.type == "cuda",
    )

    model = resnet18(num_classes=20)
    # map_location lets a checkpoint saved on a GPU be restored on the
    # CPU-only fallback device selected above.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    try:
        # NOTE(review): this appears to be a whole pickled object rather than
        # a state dict. PyTorch >= 2.6 defaults torch.load to
        # weights_only=True, which may reject it. Pass weights_only=False
        # there if the file is trusted.
        openmax = torch.load(openmax_path, map_location=device)
    except Exception as e:
        print(f"Error loading models: {e}")
        return
    print("Successfully loaded OpenMax model")

    print("\n=== Testing ResNet (Known Classes Only) ===")
    _, accuracy, errors = evaluate_known_classes(
        model, test_loader, torch.nn.CrossEntropyLoss(), device
    )
    print(f"Known Classes Accuracy: {accuracy:.2f}%")
    if errors:
        print("\nErrors in known classes:")
        pprint(errors)

    print("\n=== Testing ResNet + OpenMax ===")
    evaluate_openmax(
        openmax, model, test_loader, device,
        multiplier=0.5, fraction=0.2, verbose=True,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# never as a side effect of importing this module.
if "__main__" == __name__:
    test_models()
|
|
|