Facepalm0 commited on
Commit
b034760
·
verified ·
1 Parent(s): d52d35f

Upload post_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. post_train.py +191 -0
post_train.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from models.resnet import resnet18, resnet34, resnet50
4
+ from models.openmax import OpenMax
5
+ from models.metamax import MetaMax
6
+ from train import GameDataset
7
+ from utils.eval_utils import evaluate_openmax, evaluate_metamax
8
+ from torchvision import transforms
9
+ from utils.data_stats import load_dataset_stats
10
+ from pprint import pprint
11
+
12
+ def prepare_data_and_model(model_path='models/best_model.pth', model_type='resnet18', batch_size=400):
13
+ """准备数据和模型"""
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ # 加载数据集统计信息和准备数据
17
+ mean, std = load_dataset_stats()
18
+ transform = transforms.Compose([
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=mean, std=std)
21
+ ])
22
+
23
+ # 加载训练集和验证集
24
+ train_dataset = GameDataset('jk_zfls/round0_train', num_labels=20, transform=transform)
25
+ val_dataset = GameDataset('jk_zfls/round0_eval', num_labels=21, transform=transform)
26
+
27
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
28
+ num_workers=4, pin_memory=True)
29
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
30
+ num_workers=4, pin_memory=True)
31
+
32
+ # 加载预训练模型
33
+ if model_type == 'resnet18':
34
+ model = resnet18(num_classes=20)
35
+ elif model_type == 'resnet34':
36
+ model = resnet34(num_classes=20)
37
+ elif model_type == 'resnet50':
38
+ model = resnet50(num_classes=20)
39
+
40
+ checkpoint = torch.load(model_path)
41
+ model.load_state_dict(checkpoint['model_state_dict'])
42
+ model = model.to(device)
43
+ model.eval()
44
+
45
+ return model, train_loader, val_loader, device
46
+
47
+ def collect_features(model, loader, device, return_logits=False):
48
+ """收集特征和标签"""
49
+ features_list = []
50
+ logits_list = []
51
+ labels_list = []
52
+
53
+ with torch.no_grad():
54
+ for images, labels, paths in loader:
55
+ images = images.to(device)
56
+ if return_logits:
57
+ logits, features = model(images, return_features=True)
58
+ logits_list.append(logits.cpu())
59
+ else:
60
+ _, features = model(images, return_features=True)
61
+ features_list.append(features.cpu())
62
+ labels_list.append(labels)
63
+
64
+ if return_logits:
65
+ return torch.cat(features_list), torch.cat(logits_list), torch.cat(labels_list)
66
+ else:
67
+ return torch.cat(features_list), torch.cat(labels_list)
68
+
69
+ def train_openmax(features,labels, model, val_loader, device, fraction=0.2):
70
+ """训练和评估OpenMax模型
71
+ fraction: 未知类别比例
72
+ """
73
+ # OpenMax特定的超参数搜索空间
74
+ # alpha_range = [3, 5, 8, 12, 16, 20]
75
+ alpha_range = [12, 13, 14, 15, 16, 17, 18, 19, 20]
76
+ # tailsize_range = [10, 15, 20, 25, 30]
77
+ tailsize_range = [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
78
+ # multiplier_range = [0.5, 0.75, 1, 1.25, 1.5]
79
+ multiplier_range = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
80
+
81
+ best_params = {
82
+ 'alpha': None,
83
+ 'tailsize': None,
84
+ 'multiplier': None,
85
+ 'accuracy': .0,
86
+ 'model': None
87
+ }
88
+ val_features, val_logits, val_labels = collect_features(model, val_loader, device, return_logits=True)
89
+ print("\n=== Training OpenMax ===")
90
+ for alpha in alpha_range:
91
+ for tailsize in tailsize_range:
92
+ print(f"\nTraining OpenMax with alpha={alpha}, tailsize={tailsize}")
93
+
94
+ openmax = OpenMax(num_classes=20, tailsize=tailsize, alpha=alpha)
95
+ openmax.fit(features, labels)
96
+ print(f"Training finished, evaluating...")
97
+ for multiplier in multiplier_range:
98
+ overall_acc, known_acc, unknown_acc = evaluate_openmax(
99
+ openmax, val_features, val_logits, val_labels, multiplier=multiplier, fraction=fraction, verbose=False
100
+ )
101
+ if overall_acc > best_params['accuracy']:
102
+ best_params.update({
103
+ 'alpha': alpha,
104
+ 'tailsize': tailsize,
105
+ 'multiplier': multiplier,
106
+ 'accuracy': overall_acc,
107
+ 'model': openmax
108
+ })
109
+ print(f"\nNew best OpenMax parameters found:")
110
+ print(f"Alpha: {alpha}")
111
+ print(f"Tailsize: {tailsize}")
112
+ print(f"Multiplier: {multiplier}")
113
+ print(f"Overall Accuracy: {overall_acc:.2f}%")
114
+ print(f"Known Classes Accuracy: {known_acc:.2f}%")
115
+ print(f"Unknown Class Accuracy: {unknown_acc:.2f}%")
116
+
117
+ elif overall_acc > 95.0:
118
+
119
+ print(f"Alpha: {alpha}")
120
+ print(f"Tailsize: {tailsize}")
121
+ print(f"Multiplier: {multiplier}")
122
+ print(f"Overall Accuracy: {overall_acc:.2f}%")
123
+ print(f"Known Classes Accuracy: {known_acc:.2f}%")
124
+ print(f"Unknown Class Accuracy: {unknown_acc:.2f}%")
125
+
126
+
127
+
128
+ return best_params
129
+
130
+ def train_metamax(features, labels, model, val_loader, device):
131
+ """训练和评估MetaMax模型"""
132
+ # MetaMax特定的超参数搜索空间
133
+ meta_ratio_range = [0.05, 0.1, 0.15, 0.2, 0.25]
134
+ threshold_range = [0.1, 0.2, 0.3, 0.4, 0.5]
135
+
136
+ best_params = {
137
+ 'meta_ratio': None,
138
+ 'threshold': None,
139
+ 'accuracy': .0,
140
+ 'model': None
141
+ }
142
+
143
+ print("\n=== Training MetaMax ===")
144
+ for meta_ratio in meta_ratio_range:
145
+ print(f"\nTesting MetaMax with meta_ratio={meta_ratio}")
146
+ metamax = MetaMax(num_classes=20, meta_ratio=meta_ratio)
147
+ metamax.fit(features, labels)
148
+
149
+ for threshold in threshold_range:
150
+ overall_acc, known_acc, unknown_acc = evaluate_metamax(
151
+ metamax, model, val_loader, device, threshold=threshold, verbose=False
152
+ )
153
+
154
+ if overall_acc > best_params['accuracy']:
155
+ best_params.update({
156
+ 'meta_ratio': meta_ratio,
157
+ 'threshold': threshold,
158
+ 'accuracy': overall_acc,
159
+ 'model': metamax
160
+ })
161
+ if overall_acc > 90.0:
162
+ print(f"\nNew best MetaMax parameters found:")
163
+ print(f"Meta Ratio: {meta_ratio}")
164
+ print(f"Threshold: {threshold}")
165
+ print(f"Overall Accuracy: {overall_acc:.2f}%")
166
+ print(f"Known Classes Accuracy: {known_acc:.2f}%")
167
+ print(f"Unknown Class Accuracy: {unknown_acc:.2f}%")
168
+
169
+ return best_params
170
+
171
+ if __name__ == '__main__':
172
+ # 准备数据和模型
173
+ model, train_loader, val_loader, device = prepare_data_and_model(model_path='models/resnet50_99.92.pth', model_type='resnet50', batch_size=128)
174
+
175
+ # 收集特征
176
+ features, labels = collect_features(model, train_loader, device, return_logits=False)
177
+
178
+ # 训练OpenMax
179
+ best_openmax_params = train_openmax(features, labels, model, val_loader, device)
180
+ print("\nSaving OpenMax model...")
181
+ pprint(best_openmax_params)
182
+ torch.save(best_openmax_params['model'], f'models/resnet50_openmax_{best_openmax_params["accuracy"]:.2f}.pth')
183
+ print(f"OpenMax model saved to models/resnet50_openmax_{best_openmax_params['accuracy']:.2f}.pth")
184
+
185
+ # 训练MetaMax
186
+ # best_metamax_params = train_metamax(features, labels, model, val_loader, device)
187
+ # print("\nSaving MetaMax model...")
188
+ # pprint(best_metamax_params)
189
+ # torch.save(best_metamax_params['model'], 'models/best_metamax.pth')
190
+ # print(f"MetaMax model saved to models/best_metamax.pth")
191
+