# Ubiquant_CharacterHunter / post_train.py
# Hugging Face Hub page header: uploader avatar "Facepalm0's picture"
# Commit message: "Upload post_train.py with huggingface_hub"
# Commit: b034760 (verified)
import torch
from torch.utils.data import DataLoader
from models.resnet import resnet18, resnet34, resnet50
from models.openmax import OpenMax
from models.metamax import MetaMax
from train import GameDataset
from utils.eval_utils import evaluate_openmax, evaluate_metamax
from torchvision import transforms
from utils.data_stats import load_dataset_stats
from pprint import pprint
def prepare_data_and_model(model_path='models/best_model.pth', model_type='resnet18', batch_size=400):
    """Build the train/eval data loaders and load the pretrained backbone.

    Args:
        model_path: Checkpoint file; the weights are read from its
            ``'model_state_dict'`` entry.
        model_type: One of ``'resnet18'``, ``'resnet34'`` or ``'resnet50'``.
        batch_size: Batch size shared by both loaders.

    Returns:
        ``(model, train_loader, val_loader, device)``. The model has been moved
        to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Fail fast on an unknown architecture, before any dataset or checkpoint
    # I/O, instead of hitting an unbound ``model`` variable further down.
    supported_types = ('resnet18', 'resnet34', 'resnet50')
    if model_type not in supported_types:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {supported_types}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Normalize inputs with the dataset statistics returned by the project helper.
    mean, std = load_dataset_stats()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # Training split is built with 20 labels, evaluation split with 21.
    # NOTE(review): presumably the extra eval label is the "unknown" class;
    # confirm against GameDataset.
    train_dataset = GameDataset('jk_zfls/round0_train', num_labels=20, transform=transform)
    val_dataset = GameDataset('jk_zfls/round0_eval', num_labels=21, transform=transform)

    # Pinned host memory only helps host-to-GPU copies; without CUDA it just
    # triggers a warning, so enable it only when running on a GPU.
    use_pinned_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=4, pin_memory=use_pinned_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=use_pinned_memory)

    # Instantiate the requested backbone with a 20-class head.
    builders = {'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50}
    model = builders[model_type](num_classes=20)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model, train_loader, val_loader, device
def collect_features(model, loader, device, return_logits=False):
    """Run ``model`` over every batch of ``loader`` and gather its outputs.

    Args:
        model: Callable invoked as ``model(images, return_features=True)`` and
            expected to return a ``(logits, features)`` pair.
        loader: Iterable yielding ``(images, labels, paths)`` triples; the
            paths are not used here.
        device: Device the images are moved to before the forward pass.
        return_logits: When true, the concatenated logits are returned as well.

    Returns:
        ``(features, labels)``, or ``(features, logits, labels)`` when
        ``return_logits`` is true. Features and logits are moved to the CPU
        before concatenation; labels are concatenated as yielded by the loader.
    """
    feature_chunks, logit_chunks, label_chunks = [], [], []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch_images, batch_labels, _paths in loader:
            batch_logits, batch_features = model(batch_images.to(device), return_features=True)
            if return_logits:
                logit_chunks.append(batch_logits.cpu())
            feature_chunks.append(batch_features.cpu())
            label_chunks.append(batch_labels)

    if return_logits:
        return torch.cat(feature_chunks), torch.cat(logit_chunks), torch.cat(label_chunks)
    return torch.cat(feature_chunks), torch.cat(label_chunks)
def train_openmax(features, labels, model, val_loader, device, fraction=0.2):
    """Grid-search OpenMax hyperparameters and return the best configuration.

    One OpenMax instance is fitted per ``(alpha, tailsize)`` pair on the given
    features/labels. Each fitted instance is then scored with
    ``evaluate_openmax`` for every multiplier, using features and logits
    collected once from ``val_loader``.

    Args:
        features: Training features passed to ``OpenMax.fit``.
        labels: Training labels passed to ``OpenMax.fit``.
        model: Network used to extract the validation features and logits.
        val_loader: Loader of the validation split.
        device: Device used for the validation forward passes.
        fraction: Proportion of unknown classes; forwarded unchanged to
            ``evaluate_openmax``.

    Returns:
        Dict with keys ``'alpha'``, ``'tailsize'``, ``'multiplier'``,
        ``'accuracy'`` and ``'model'`` describing the configuration with the
        highest overall accuracy (``None`` entries if nothing scored above 0).
    """
    # Current search grids. Earlier, coarser grids kept for reference:
    #   alpha:      [3, 5, 8, 12, 16, 20]
    #   tailsize:   [10, 15, 20, 25, 30]
    #   multiplier: [0.5, 0.75, 1, 1.25, 1.5]
    alpha_grid = list(range(12, 21))
    tailsize_grid = list(range(5, 16))
    multiplier_grid = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    best_params = dict(alpha=None, tailsize=None, multiplier=None, accuracy=0.0, model=None)

    def report(alpha, tailsize, multiplier, overall_acc, known_acc, unknown_acc):
        # Shared summary used for both "new best" and "above 95" results.
        print(f"Alpha: {alpha}")
        print(f"Tailsize: {tailsize}")
        print(f"Multiplier: {multiplier}")
        print(f"Overall Accuracy: {overall_acc:.2f}%")
        print(f"Known Classes Accuracy: {known_acc:.2f}%")
        print(f"Unknown Class Accuracy: {unknown_acc:.2f}%")

    # Validation outputs do not depend on the hyperparameters: compute them once.
    val_features, val_logits, val_labels = collect_features(
        model, val_loader, device, return_logits=True
    )

    print("\n=== Training OpenMax ===")
    for alpha in alpha_grid:
        for tailsize in tailsize_grid:
            print(f"\nTraining OpenMax with alpha={alpha}, tailsize={tailsize}")
            candidate = OpenMax(num_classes=20, tailsize=tailsize, alpha=alpha)
            candidate.fit(features, labels)
            print("Training finished, evaluating...")

            # The multiplier is only used at evaluation time, so the fitted
            # model is reused for every value.
            for multiplier in multiplier_grid:
                overall_acc, known_acc, unknown_acc = evaluate_openmax(
                    candidate, val_features, val_logits, val_labels,
                    multiplier=multiplier, fraction=fraction, verbose=False
                )
                if overall_acc > best_params['accuracy']:
                    best_params.update(
                        alpha=alpha,
                        tailsize=tailsize,
                        multiplier=multiplier,
                        accuracy=overall_acc,
                        model=candidate,
                    )
                    print("\nNew best OpenMax parameters found:")
                    report(alpha, tailsize, multiplier, overall_acc, known_acc, unknown_acc)
                elif overall_acc > 95.0:
                    # Not a new best, but still worth logging.
                    report(alpha, tailsize, multiplier, overall_acc, known_acc, unknown_acc)
    return best_params
def train_metamax(features, labels, model, val_loader, device):
    """Grid-search MetaMax hyperparameters and return the best configuration.

    One MetaMax instance is fitted per ``meta_ratio`` on the given
    features/labels and scored with ``evaluate_metamax`` for every threshold.

    Args:
        features: Training features passed to ``MetaMax.fit``.
        labels: Training labels passed to ``MetaMax.fit``.
        model: Network forwarded to ``evaluate_metamax``.
        val_loader: Validation loader forwarded to ``evaluate_metamax``.
        device: Device forwarded to ``evaluate_metamax``.

    Returns:
        Dict with keys ``'meta_ratio'``, ``'threshold'``, ``'accuracy'`` and
        ``'model'`` describing the configuration with the highest overall
        accuracy (``None`` entries if nothing scored above 0).
    """
    # MetaMax-specific search grids.
    meta_ratio_grid = [0.05, 0.1, 0.15, 0.2, 0.25]
    threshold_grid = [0.1, 0.2, 0.3, 0.4, 0.5]

    best_params = dict(meta_ratio=None, threshold=None, accuracy=0.0, model=None)

    print("\n=== Training MetaMax ===")
    for meta_ratio in meta_ratio_grid:
        print(f"\nTesting MetaMax with meta_ratio={meta_ratio}")
        candidate = MetaMax(num_classes=20, meta_ratio=meta_ratio)
        candidate.fit(features, labels)

        # The threshold only matters at evaluation time, so the fitted model
        # is reused for every value.
        for threshold in threshold_grid:
            overall_acc, known_acc, unknown_acc = evaluate_metamax(
                candidate, model, val_loader, device, threshold=threshold, verbose=False
            )
            if not (overall_acc > best_params['accuracy']):
                continue

            best_params.update(
                meta_ratio=meta_ratio,
                threshold=threshold,
                accuracy=overall_acc,
                model=candidate,
            )
            # Only announce new bests that clear 90 to keep the log short.
            if overall_acc > 90.0:
                print("\nNew best MetaMax parameters found:")
                print(f"Meta Ratio: {meta_ratio}")
                print(f"Threshold: {threshold}")
                print(f"Overall Accuracy: {overall_acc:.2f}%")
                print(f"Known Classes Accuracy: {known_acc:.2f}%")
                print(f"Unknown Class Accuracy: {unknown_acc:.2f}%")
    return best_params
if __name__ == '__main__':
    # Load the ResNet-50 checkpoint together with the train/eval loaders.
    model, train_loader, val_loader, device = prepare_data_and_model(
        model_path='models/resnet50_99.92.pth',
        model_type='resnet50',
        batch_size=128,
    )

    # Extract training-set features (logits are not needed for fitting).
    features, labels = collect_features(model, train_loader, device, return_logits=False)

    # Run the OpenMax grid search and persist the best fitted model; the
    # achieved accuracy is embedded in the file name.
    best_openmax_params = train_openmax(features, labels, model, val_loader, device)
    print("\nSaving OpenMax model...")
    pprint(best_openmax_params)
    openmax_path = f'models/resnet50_openmax_{best_openmax_params["accuracy"]:.2f}.pth'
    torch.save(best_openmax_params['model'], openmax_path)
    print(f"OpenMax model saved to {openmax_path}")

    # MetaMax training (currently disabled)
    # best_metamax_params = train_metamax(features, labels, model, val_loader, device)
    # print("\nSaving MetaMax model...")
    # pprint(best_metamax_params)
    # torch.save(best_metamax_params['model'], 'models/best_metamax.pth')
    # print(f"MetaMax model saved to models/best_metamax.pth")