Upload visualize_features.py with huggingface_hub
Browse files- visualize_features.py +143 -0
visualize_features.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from sklearn.manifold import TSNE
|
4 |
+
import seaborn as sns
|
5 |
+
import umap
|
6 |
+
import os
|
7 |
+
from post_train import collect_features, prepare_data_and_model
|
8 |
+
|
9 |
+
def visualize_features(features, labels, method='tsne', save_path=None, include_unknown=False):
    """
    Project features to 2-D and draw a per-class scatter plot.

    The features are reduced with t-SNE or UMAP (both seeded with
    ``random_state=42``), then every class index is drawn as its own scatter
    series so that it gets a separate legend entry.

    Args:
        features: torch.Tensor of feature vectors, one row per sample.
        labels: torch.Tensor of integer class labels aligned with ``features``.
        method: 'tsne' (case-insensitive) selects t-SNE; any other value
            falls back to UMAP.
        save_path: str or None. When given, the figure is written to this
            path and missing parent directories are created. The figure is
            shown with ``plt.show()`` whether or not it was saved.
        include_unknown: bool, whether to also draw the unknown class
            (label 20, i.e. the 21st class).

    Note:
        All currently open matplotlib figures are closed before plotting.
    """
    # Convert to numpy arrays. detach() comes first so tensors that still
    # track gradients can be converted too.
    features = features.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()

    # Dimensionality reduction.
    print(f"Performing {method.upper()} dimensionality reduction...")
    if method.lower() == 'tsne':
        reducer = TSNE(n_components=2, random_state=42)
    else:  # umap
        reducer = umap.UMAP(n_components=2, random_state=42)
    embedded = reducer.fit_transform(features)

    # Drop any previous figure state and start a fresh figure.
    plt.close('all')
    plt.figure(figsize=(15, 10))

    # 4 markers x 5 colours cycle independently, which yields 20 distinct
    # (marker, colour) pairs -- one for each known class.
    markers = ['o', 's', '^', 'D']  # circle, square, triangle, diamond
    colors = ['#FF4B4B', '#4B4BFF', '#4BFF4B', '#FFB74B', '#B74BFF']  # red, blue, green, orange, purple

    # Number of class indices to draw.
    num_classes = 21 if include_unknown else 20

    for i in range(num_classes):
        mask = labels == i
        if i == 20:
            # The unknown class stands out: gray, star-shaped, slightly larger.
            style = {'c': 'gray', 'marker': '*', 's': 150, 'label': 'Unknown'}
        else:
            style = {
                'c': colors[i % len(colors)],
                'marker': markers[i % len(markers)],
                's': 100,
                'label': f'Class {i}',
            }
        plt.scatter(
            embedded[mask, 0],
            embedded[mask, 1],
            alpha=0.6,
            edgecolors='white',
            linewidth=0.5,
            **style,
        )

    plt.grid(True, linestyle='--', alpha=0.7)
    plt.title(f'Feature Distribution ({method.upper()})', fontsize=14, pad=20)

    # Place a single-column legend outside the axes, to the right.
    plt.legend(bbox_to_anchor=(1.05, 1),
               loc='upper left',
               borderaxespad=0,
               ncol=1,
               fontsize=10)

    plt.tight_layout()

    # Save first, then show.
    if save_path:
        print(f"Saving plot to {save_path}")
        # os.path.dirname() returns '' for a bare file name and
        # os.makedirs('') raises, so only create a directory when there is one.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight', dpi=300, pad_inches=0.5)

    plt.show()
    plt.close()
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == "__main__":
    # Report the locally detected device. Note that `device` is rebound just
    # below to the value returned by prepare_data_and_model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model, train_loader, val_loader, device = prepare_data_and_model(
        model_path='models/resnet50_99.92.pth',
        model_type='resnet50',
        batch_size=128,
    )

    # Extract features for both splits before any plotting starts.
    print("Collecting features of training set...")
    features, labels = collect_features(model=model, loader=train_loader, device=device)

    print("Collecting features of validation set...")
    val_features, val_labels = collect_features(model=model, loader=val_loader, device=device)

    # One UMAP plot per split; only the validation plot draws the unknown class.
    # (Pass method='tsne' to visualize_features for a t-SNE projection instead.)
    print("Visualizing features using UMAP...")
    plot_jobs = (
        (features, labels, 'outputs/resnet50_train_FeatureMap.png', False),
        (val_features, val_labels, 'outputs/resnet50_val_FeatureMap.png', True),
    )
    for split_features, split_labels, out_path, with_unknown in plot_jobs:
        visualize_features(
            features=split_features,
            labels=split_labels,
            method='umap',
            save_path=out_path,
            include_unknown=with_unknown,
        )
|