# Uploaded by Facepalm0 as test.py with huggingface_hub (commit 0efb5b8, verified).
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from models.resnet import resnet18
from models.openmax import OpenMax
from models.metamax import MetaMax
from train import GameDataset
from utils.data_stats import load_dataset_stats
from utils.eval_utils import evaluate_known_classes, evaluate_openmax, evaluate_metamax
import os
from pprint import pprint
def test_models(*, eval_dir='jk_zfls/round0_eval',
                model_path='models/best_model_99.92_02.pth',
                openmax_path='models/best_openmax_94.71_01.pth'):
    """Evaluate the base ResNet-18 classifier and its OpenMax extension.

    Loads the evaluation split, the ResNet-18 checkpoint and a pickled
    OpenMax object, then prints:

    1. the known-class accuracy of the plain ResNet (plus any errors that
       ``evaluate_known_classes`` reports), and
    2. the verbose output of ``evaluate_openmax``.

    MetaMax evaluation is currently disabled (see the commented-out lines).

    Args:
        eval_dir: Dataset directory handed to ``GameDataset``.
        model_path: Checkpoint file holding a ``'model_state_dict'`` entry
            for the ResNet-18.
        openmax_path: File containing the whole pickled OpenMax object.

    Returns:
        None. Results are printed. If the OpenMax model cannot be loaded,
        the error is printed and the function returns early.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # Load the dataset statistics used for input normalization.
    mean, std = load_dataset_stats()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Load the evaluation set.
    # NOTE(review): num_labels=21 here vs num_classes=20 for the model below
    # presumably means 20 known classes plus one "unknown" label — confirm
    # against GameDataset.
    test_dataset = GameDataset(eval_dir, num_labels=21, transform=transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=400,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only speeds up host->GPU copies; on a CPU-only
        # machine it brings nothing and PyTorch emits a warning.
        pin_memory=use_cuda,
    )

    # Load the base classifier. map_location lets a checkpoint that was saved
    # on a GPU be restored on a CPU-only machine instead of raising.
    model = resnet18(num_classes=20)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Load the OpenMax model (stored as a whole pickled object, not a
    # state dict). Tensors are remapped only when CUDA is unavailable, so on
    # a GPU machine the object keeps the placement it was saved with.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # may refuse to unpickle a custom object like this one; pass
    # weights_only=False if so (only for trusted files — unpickling can
    # execute arbitrary code).
    openmax_map_location = None if use_cuda else device
    try:
        openmax = torch.load(openmax_path, map_location=openmax_map_location)
        # metamax = torch.load('models/best_metamax.pth')
    except Exception as e:
        print(f"Error loading models: {e}")
        return
    else:
        print("Successfully loaded OpenMax model")

    # Evaluate the plain ResNet on known classes only.
    print("\n=== Testing ResNet (Known Classes Only) ===")
    _, accuracy, errors = evaluate_known_classes(
        model, test_loader, torch.nn.CrossEntropyLoss(), device
    )
    print(f"Known Classes Accuracy: {accuracy:.2f}%")
    if errors:
        print("\nErrors in known classes:")
        pprint(errors)

    # Evaluate ResNet + OpenMax.
    print("\n=== Testing ResNet + OpenMax ===")
    evaluate_openmax(openmax, model, test_loader, device,
                     multiplier=0.5, fraction=0.2, verbose=True)

    # Evaluate ResNet + MetaMax (disabled; re-enable together with the
    # metamax load above).
    # print("\n=== Testing ResNet + MetaMax ===")
    # evaluate_metamax(metamax, model, test_loader, device, threshold=0.5, verbose=True)
# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    test_models()