Facepalm0 commited on
Commit
0efb5b8
·
verified ·
1 Parent(s): a821f69

Upload test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test.py +60 -0
test.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from torchvision import transforms
4
+ from models.resnet import resnet18
5
+ from models.openmax import OpenMax
6
+ from models.metamax import MetaMax
7
+ from train import GameDataset
8
+ from utils.data_stats import load_dataset_stats
9
+ from utils.eval_utils import evaluate_known_classes, evaluate_openmax, evaluate_metamax
10
+ import os
11
+ from pprint import pprint
12
+
13
+ def test_models():
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # 加载数据集统计信息
17
+ mean, std = load_dataset_stats()
18
+ transform = transforms.Compose([
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=mean, std=std)
21
+ ])
22
+
23
+ # 加载验证集
24
+ test_dataset = GameDataset('jk_zfls/round0_eval', num_labels=21, transform=transform)
25
+ test_loader = DataLoader(test_dataset, batch_size=400, shuffle=False, num_workers=4, pin_memory=True)
26
+
27
+ # 加载基础模型
28
+ model = resnet18(num_classes=20)
29
+ checkpoint = torch.load('models/best_model_99.92_02.pth')
30
+ model.load_state_dict(checkpoint['model_state_dict'])
31
+ model = model.to(device)
32
+ model.eval()
33
+
34
+ # 加载OpenMax和MetaMax模型
35
+ try:
36
+ openmax = torch.load('models/best_openmax_94.71_01.pth')
37
+ # metamax = torch.load('models/best_metamax.pth')
38
+ print("Successfully loaded OpenMax and MetaMax models")
39
+ except Exception as e:
40
+ print(f"Error loading models: {e}")
41
+ return
42
+
43
+ # 测试基础ResNet
44
+ print("\n=== Testing ResNet (Known Classes Only) ===")
45
+ _, accuracy, errors = evaluate_known_classes(model, test_loader, torch.nn.CrossEntropyLoss(), device)
46
+ print(f"Known Classes Accuracy: {accuracy:.2f}%")
47
+ if errors:
48
+ print("\nErrors in known classes:")
49
+ pprint(errors)
50
+
51
+ # 测试ResNet + OpenMax
52
+ print("\n=== Testing ResNet + OpenMax ===")
53
+ evaluate_openmax(openmax, model, test_loader, device, multiplier=0.5, fraction=0.2, verbose=True)
54
+
55
+ # 测试ResNet + MetaMax
56
+ # print("\n=== Testing ResNet + MetaMax ===")
57
+ # evaluate_metamax(metamax, model, test_loader, device, threshold=0.5, verbose=True)
58
+
59
+ if __name__ == '__main__':
60
+ test_models()