|
import torch
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.manifold import TSNE
|
|
import seaborn as sns
|
|
import umap
|
|
import os
|
|
from post_train import collect_features, prepare_data_and_model
|
|
|
|
def _reduce_to_2d(features, method):
    """Project a feature matrix to two dimensions with t-SNE or UMAP.

    Args:
        features: 2-D array (one row per sample) passed to the reducer's
            ``fit_transform``.
        method: 'tsne' or 'umap' (case-insensitive).

    Returns:
        The 2-D embedding produced by ``fit_transform`` (one row per sample).

    Raises:
        ValueError: if ``method`` is neither 'tsne' nor 'umap'.
    """
    method_key = method.lower()
    if method_key == 'tsne':
        reducer = TSNE(n_components=2, random_state=42)
    elif method_key == 'umap':
        reducer = umap.UMAP(n_components=2, random_state=42)
    else:
        # Fail loudly instead of silently running a different algorithm and
        # labelling the plot with the requested (wrong) method name.
        raise ValueError(
            f"Unsupported method {method!r}; expected 'tsne' or 'umap'"
        )
    return reducer.fit_transform(features)


def visualize_features(features, labels, method='tsne', save_path=None,
                       include_unknown=False, num_known_classes=20):
    """Visualize the 2-D distribution of feature vectors, styled by class.

    Args:
        features: torch.Tensor of feature vectors, one row per sample.
        labels: torch.Tensor of integer class labels aligned with ``features``.
        method: 'tsne' or 'umap' (case-insensitive) dimensionality reduction.
        save_path: file path to save the figure to; if None, the figure is
            displayed instead.
        include_unknown: whether to also plot the unknown class, whose label
            is ``num_known_classes`` (label 20, i.e. the 21st class, by
            default).
        num_known_classes: number of known classes, labelled
            ``0 .. num_known_classes - 1``. Defaults to 20.

    Raises:
        ValueError: if ``method`` is neither 'tsne' nor 'umap'.
    """
    # detach() so tensors that still track gradients can be converted.
    features = features.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()

    print(f"Performing {method.upper()} dimensionality reduction...")
    embedded = _reduce_to_2d(features, method)

    # Start from a clean slate so figures from earlier calls don't linger.
    plt.close('all')
    fig = plt.figure(figsize=(15, 10))

    # 4 markers x 5 colors: because lcm(4, 5) == 20, classes 0..19 each get a
    # distinct (marker, color) combination; beyond that, combinations repeat.
    markers = ['o', 's', '^', 'D']
    colors = ['#FF4B4B', '#4B4BFF', '#4BFF4B', '#FFB74B', '#B74BFF']

    unknown_label = num_known_classes
    num_classes = num_known_classes + 1 if include_unknown else num_known_classes

    for class_idx in range(num_classes):
        mask = labels == class_idx
        if class_idx == unknown_label:
            style = dict(c='gray', marker='*', s=150, label='Unknown')
        else:
            style = dict(
                c=colors[class_idx % len(colors)],
                marker=markers[class_idx % len(markers)],
                s=100,
                label=f'Class {class_idx}',
            )
        plt.scatter(
            embedded[mask, 0],
            embedded[mask, 1],
            alpha=0.6,
            edgecolors='white',
            linewidth=0.5,
            **style,
        )

    plt.grid(True, linestyle='--', alpha=0.7)
    plt.title(f'Feature Distribution ({method.upper()})', fontsize=14, pad=20)

    # Legend outside the axes on the right so it doesn't cover points.
    plt.legend(bbox_to_anchor=(1.05, 1),
               loc='upper left',
               borderaxespad=0,
               ncol=1,
               fontsize=10)

    plt.tight_layout()

    if save_path:
        print(f"Saving plot to {save_path}")
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight', dpi=300, pad_inches=0.5)
    else:
        plt.show()

    plt.close(fig)
|
|
|
|
|
|
if __name__ == "__main__":
    # Load the model and data loaders; the returned device is the one passed
    # to collect_features below, so that is the device worth reporting.
    model, train_loader, val_loader, device = prepare_data_and_model(
        model_path='models/resnet50_99.92.pth',
        model_type='resnet50',
        batch_size=128,
    )
    print(f"Using device: {device}")

    print("Collecting features of training set...")
    features, labels = collect_features(
        model=model,
        loader=train_loader,
        device=device
    )

    print("Collecting features of validation set...")
    val_features, val_labels = collect_features(
        model=model,
        loader=val_loader,
        device=device
    )

    print("Visualizing features using UMAP...")
    # Training split: known classes only.
    visualize_features(
        features=features,
        labels=labels,
        method='umap',
        save_path='outputs/resnet50_train_FeatureMap.png',
        include_unknown=False
    )

    # Validation split also plots the unknown class (label 20).
    # NOTE(review): presumably only the validation set contains unknown
    # samples — confirm against the dataset definition.
    visualize_features(
        features=val_features,
        labels=val_labels,
        method='umap',
        save_path='outputs/resnet50_val_FeatureMap.png',
        include_unknown=True
    )
|
|
|