Upload post_train.py with huggingface_hub
Browse files- post_train.py +191 -0
post_train.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from models.resnet import resnet18, resnet34, resnet50
|
4 |
+
from models.openmax import OpenMax
|
5 |
+
from models.metamax import MetaMax
|
6 |
+
from train import GameDataset
|
7 |
+
from utils.eval_utils import evaluate_openmax, evaluate_metamax
|
8 |
+
from torchvision import transforms
|
9 |
+
from utils.data_stats import load_dataset_stats
|
10 |
+
from pprint import pprint
|
11 |
+
|
12 |
+
def prepare_data_and_model(model_path='models/best_model.pth', model_type='resnet18', batch_size=400):
    """Build the train/eval data loaders and load the trained classifier.

    Args:
        model_path: Path of the checkpoint file. The checkpoint must contain
            a ``'model_state_dict'`` entry.
        model_type: Backbone name: ``'resnet18'``, ``'resnet34'`` or
            ``'resnet50'``.
        batch_size: Batch size used by both data loaders.

    Returns:
        A ``(model, train_loader, val_loader, device)`` tuple. The model has
        been moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported backbones.
    """
    # Validate up front so a typo fails immediately with a clear message
    # instead of an UnboundLocalError after the datasets have been built.
    supported_types = ('resnet18', 'resnet34', 'resnet50')
    if model_type not in supported_types:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {supported_types}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the dataset statistics and build the preprocessing pipeline.
    mean, std = load_dataset_stats()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Load the training and evaluation sets.
    # NOTE(review): the eval set is created with one more label than the
    # training set (21 vs 20) - presumably the extra "unknown" class; confirm
    # against GameDataset.
    train_dataset = GameDataset('jk_zfls/round0_train', num_labels=20, transform=transform)
    val_dataset = GameDataset('jk_zfls/round0_eval', num_labels=21, transform=transform)

    # shuffle=False on both loaders keeps the sample order deterministic.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    # Build the backbone and restore the trained weights.
    builders = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
    }
    model = builders[model_type](num_classes=20)

    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (and vice versa).
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    return model, train_loader, val_loader, device
|
46 |
+
|
47 |
+
def collect_features(model, loader, device, return_logits=False):
    """Run ``model`` over ``loader`` and gather features, labels and optionally logits.

    Args:
        model: Callable invoked as ``model(images, return_features=True)``
            that returns a ``(logits, features)`` pair.
        loader: Iterable yielding ``(images, labels, paths)`` batches; the
            paths are ignored here.
        device: Device the images are moved to before the forward pass.
        return_logits: When true, the concatenated logits are returned too.

    Returns:
        ``(features, labels)``, or ``(features, logits, labels)`` when
        ``return_logits`` is true. Features and logits are moved to the CPU;
        every tensor is the concatenation of the per-batch tensors.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    features_list = []
    logits_list = []
    labels_list = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels, _paths in loader:
            logits, features = model(images.to(device), return_features=True)
            features_list.append(features.cpu())
            labels_list.append(labels)
            if return_logits:
                logits_list.append(logits.cpu())

    # torch.cat rejects an empty list with an opaque message; fail clearly.
    if not features_list:
        raise ValueError("collect_features received an empty loader; no batches to collect")

    if return_logits:
        return torch.cat(features_list), torch.cat(logits_list), torch.cat(labels_list)
    return torch.cat(features_list), torch.cat(labels_list)
|
68 |
+
|
69 |
+
def _report_openmax_result(alpha, tailsize, multiplier, overall_acc, known_acc, unknown_acc):
    """Print one OpenMax configuration together with its accuracies."""
    print(f"Alpha: {alpha}")
    print(f"Tailsize: {tailsize}")
    print(f"Multiplier: {multiplier}")
    print(f"Overall Accuracy: {overall_acc:.2f}%")
    print(f"Known Classes Accuracy: {known_acc:.2f}%")
    print(f"Unknown Class Accuracy: {unknown_acc:.2f}%")


def train_openmax(features, labels, model, val_loader, device, fraction=0.2, *,
                  alpha_range=None, tailsize_range=None, multiplier_range=None,
                  num_classes=20):
    """Grid-search OpenMax hyper-parameters and return the best configuration.

    One OpenMax model is fitted per ``(alpha, tailsize)`` pair on the given
    training features, then evaluated on the validation set once per
    multiplier.

    Args:
        features: Training features passed to ``OpenMax.fit``.
        labels: Training labels passed to ``OpenMax.fit``.
        model: Backbone used to extract validation features and logits.
        val_loader: Loader over the validation set.
        device: Device used for the validation forward passes.
        fraction: Forwarded to ``evaluate_openmax``; described by the original
            author as the ratio of unknown classes.
        alpha_range: Alpha values to try; defaults to 12..20.
        tailsize_range: Tail sizes to try; defaults to 5..15.
        multiplier_range: Multipliers to try; defaults to 0.5..1.0 in steps
            of 0.1.
        num_classes: Number of known classes given to ``OpenMax``.

    Returns:
        Dict with keys ``'alpha'``, ``'tailsize'``, ``'multiplier'``,
        ``'accuracy'`` (overall accuracy of the best configuration) and
        ``'model'`` (the fitted OpenMax instance). ``'model'`` is ``None``
        only when one of the search ranges is empty.
    """
    # Default search space. An earlier, coarser sweep used
    # alpha in [3, 5, 8, 12, 16, 20], tailsize in [10, 15, 20, 25, 30] and
    # multiplier in [0.5, 0.75, 1, 1.25, 1.5].
    if alpha_range is None:
        alpha_range = [12, 13, 14, 15, 16, 17, 18, 19, 20]
    if tailsize_range is None:
        tailsize_range = [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    if multiplier_range is None:
        multiplier_range = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    best_params = {
        'alpha': None,
        'tailsize': None,
        'multiplier': None,
        'accuracy': .0,
        'model': None,
    }

    # Validation features do not depend on the hyper-parameters, so they are
    # extracted once, outside the search loops.
    val_features, val_logits, val_labels = collect_features(model, val_loader, device, return_logits=True)

    print("\n=== Training OpenMax ===")
    for alpha in alpha_range:
        for tailsize in tailsize_range:
            print(f"\nTraining OpenMax with alpha={alpha}, tailsize={tailsize}")

            # A fresh instance per (alpha, tailsize) pair, so the object
            # stored in best_params is never refitted afterwards.
            openmax = OpenMax(num_classes=num_classes, tailsize=tailsize, alpha=alpha)
            openmax.fit(features, labels)
            print("Training finished, evaluating...")

            for multiplier in multiplier_range:
                overall_acc, known_acc, unknown_acc = evaluate_openmax(
                    openmax, val_features, val_logits, val_labels,
                    multiplier=multiplier, fraction=fraction, verbose=False
                )

                # Accept the first result unconditionally so a model is
                # returned even if every configuration scores 0.
                is_new_best = (
                    best_params['model'] is None
                    or overall_acc > best_params['accuracy']
                )
                if is_new_best:
                    best_params.update({
                        'alpha': alpha,
                        'tailsize': tailsize,
                        'multiplier': multiplier,
                        'accuracy': overall_acc,
                        'model': openmax,
                    })
                    print("\nNew best OpenMax parameters found:")
                    _report_openmax_result(alpha, tailsize, multiplier,
                                           overall_acc, known_acc, unknown_acc)
                elif overall_acc > 95.0:
                    # Not a new best, but still strong enough to be logged.
                    _report_openmax_result(alpha, tailsize, multiplier,
                                           overall_acc, known_acc, unknown_acc)

    return best_params
|
129 |
+
|
130 |
+
def train_metamax(features, labels, model, val_loader, device, *,
                  meta_ratio_range=None, threshold_range=None, num_classes=20):
    """Grid-search MetaMax hyper-parameters and return the best configuration.

    One MetaMax model is fitted per meta ratio on the given training
    features, then evaluated on the validation loader once per threshold.

    Args:
        features: Training features passed to ``MetaMax.fit``.
        labels: Training labels passed to ``MetaMax.fit``.
        model: Backbone handed to ``evaluate_metamax``.
        val_loader: Loader over the validation set.
        device: Device handed to ``evaluate_metamax``.
        meta_ratio_range: Meta ratios to try; defaults to 0.05..0.25 in steps
            of 0.05.
        threshold_range: Thresholds to try; defaults to 0.1..0.5 in steps
            of 0.1.
        num_classes: Number of known classes given to ``MetaMax``.

    Returns:
        Dict with keys ``'meta_ratio'``, ``'threshold'``, ``'accuracy'``
        (overall accuracy of the best configuration) and ``'model'`` (the
        fitted MetaMax instance). ``'model'`` is ``None`` only when one of
        the search ranges is empty.
    """
    # Default MetaMax search space.
    if meta_ratio_range is None:
        meta_ratio_range = [0.05, 0.1, 0.15, 0.2, 0.25]
    if threshold_range is None:
        threshold_range = [0.1, 0.2, 0.3, 0.4, 0.5]

    best_params = {
        'meta_ratio': None,
        'threshold': None,
        'accuracy': .0,
        'model': None,
    }

    print("\n=== Training MetaMax ===")
    for meta_ratio in meta_ratio_range:
        print(f"\nTesting MetaMax with meta_ratio={meta_ratio}")
        # A fresh instance per meta ratio, so the object stored in
        # best_params is never refitted afterwards.
        metamax = MetaMax(num_classes=num_classes, meta_ratio=meta_ratio)
        metamax.fit(features, labels)

        for threshold in threshold_range:
            overall_acc, known_acc, unknown_acc = evaluate_metamax(
                metamax, model, val_loader, device, threshold=threshold, verbose=False
            )

            # Accept the first result unconditionally so a model is returned
            # even if every configuration scores 0.
            is_new_best = (
                best_params['model'] is None
                or overall_acc > best_params['accuracy']
            )
            if not is_new_best:
                continue

            best_params.update({
                'meta_ratio': meta_ratio,
                'threshold': threshold,
                'accuracy': overall_acc,
                'model': metamax,
            })
            # Only sufficiently good improvements are reported.
            if overall_acc > 90.0:
                print("\nNew best MetaMax parameters found:")
                print(f"Meta Ratio: {meta_ratio}")
                print(f"Threshold: {threshold}")
                print(f"Overall Accuracy: {overall_acc:.2f}%")
                print(f"Known Classes Accuracy: {known_acc:.2f}%")
                print(f"Unknown Class Accuracy: {unknown_acc:.2f}%")

    return best_params
|
170 |
+
|
171 |
+
if __name__ == '__main__':
    # Prepare the data loaders and the trained backbone.
    model, train_loader, val_loader, device = prepare_data_and_model(
        model_path='models/resnet50_99.92.pth', model_type='resnet50', batch_size=128
    )

    # Collect the training-set features used to fit the open-set models.
    features, labels = collect_features(model, train_loader, device, return_logits=False)

    # Fit OpenMax and keep the best configuration found by the grid search.
    best_openmax_params = train_openmax(features, labels, model, val_loader, device)
    print("\nSaving OpenMax model...")
    pprint(best_openmax_params)

    # Never write a checkpoint that contains no model.
    if best_openmax_params['model'] is None:
        raise SystemExit("OpenMax search returned no model; nothing to save.")

    # Build the path once so the saved file and the log message cannot drift.
    openmax_path = f'models/resnet50_openmax_{best_openmax_params["accuracy"]:.2f}.pth'
    torch.save(best_openmax_params['model'], openmax_path)
    print(f"OpenMax model saved to {openmax_path}")

    # MetaMax training is currently disabled; uncomment to run it.
    # best_metamax_params = train_metamax(features, labels, model, val_loader, device)
    # print("\nSaving MetaMax model...")
    # pprint(best_metamax_params)
    # torch.save(best_metamax_params['model'], 'models/best_metamax.pth')
    # print(f"MetaMax model saved to models/best_metamax.pth")
|
191 |
+
|