Facepalm0 commited on
Commit
63eca6c
·
verified ·
1 Parent(s): 22f4d72

Upload visualize_features.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. visualize_features.py +143 -0
visualize_features.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ from sklearn.manifold import TSNE
4
+ import seaborn as sns
5
+ import umap
6
+ import os
7
+ from post_train import collect_features, prepare_data_and_model
8
+
9
+ def visualize_features(features, labels, method='tsne', save_path=None, include_unknown=False):
10
+ """
11
+ 可视化特征分布
12
+
13
+ Args:
14
+ features: torch.Tensor, 特征向量
15
+ labels: torch.Tensor, 标签
16
+ method: str, 'tsne' 或 'umap'
17
+ save_path: str, 保存路径,如果为None则显示图像
18
+ include_unknown: bool, 是否包含未知类(第21类)
19
+ """
20
+ # 转换为numpy数组
21
+ features = features.cpu().numpy()
22
+ labels = labels.cpu().numpy()
23
+
24
+ # 降维
25
+ print(f"Performing {method.upper()} dimensionality reduction...")
26
+ if method.lower() == 'tsne':
27
+ reducer = TSNE(n_components=2, random_state=42)
28
+ embedded = reducer.fit_transform(features)
29
+ else: # umap
30
+ reducer = umap.UMAP(n_components=2, random_state=42)
31
+ embedded = reducer.fit_transform(features)
32
+
33
+ # 清理之前的图像状态并创建新图形
34
+ plt.close('all') # 关闭所有图形
35
+ fig = plt.figure(figsize=(15, 10))
36
+
37
+ # 定义标记样式和颜色
38
+ markers = ['o', 's', '^', 'D'] # 圆形、方形、三角形、菱形
39
+ colors = ['#FF4B4B', '#4B4BFF', '#4BFF4B', '#FFB74B', '#B74BFF'] # 红、蓝、绿、橙、紫
40
+
41
+ # 确定要绘制的类别数量
42
+ num_classes = 21 if include_unknown else 20
43
+
44
+ # 为每个类别分配标记和颜色
45
+ for i in range(num_classes):
46
+ marker_idx = i % len(markers)
47
+ color_idx = i % len(colors)
48
+
49
+ mask = labels == i
50
+ if i == 20: # 未知类使用特殊标记
51
+ plt.scatter(
52
+ embedded[mask, 0],
53
+ embedded[mask, 1],
54
+ c='gray', # 使用灰色
55
+ marker='*', # 使用星形
56
+ s=150, # 稍微大一点
57
+ alpha=0.6,
58
+ label='Unknown',
59
+ edgecolors='white',
60
+ linewidth=0.5
61
+ )
62
+ else:
63
+ plt.scatter(
64
+ embedded[mask, 0],
65
+ embedded[mask, 1],
66
+ c=colors[color_idx],
67
+ marker=markers[marker_idx],
68
+ s=100,
69
+ alpha=0.6,
70
+ label=f'Class {i}',
71
+ edgecolors='white',
72
+ linewidth=0.5
73
+ )
74
+
75
+ plt.grid(True, linestyle='--', alpha=0.7)
76
+ plt.title(f'Feature Distribution ({method.upper()})', fontsize=14, pad=20)
77
+
78
+ # 调整图例
79
+ plt.legend(bbox_to_anchor=(1.05, 1),
80
+ loc='upper left',
81
+ borderaxespad=0,
82
+ ncol=1, # 使用单列显示图例
83
+ fontsize=10)
84
+
85
+ # 调整布局
86
+ plt.tight_layout()
87
+
88
+ # 先保存再显示
89
+ if save_path:
90
+ print(f"Saving plot to {save_path}")
91
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
92
+ plt.savefig(save_path, bbox_inches='tight', dpi=300, pad_inches=0.5)
93
+
94
+ plt.show()
95
+ plt.close()
96
+
97
+
98
+ if __name__ == "__main__":
99
+ # 设置设备
100
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
+ print(f"Using device: {device}")
102
+
103
+ model, train_loader, val_loader, device = prepare_data_and_model(model_path='models/resnet50_99.92.pth', model_type='resnet50', batch_size=128)
104
+
105
+ ## 加载特征
106
+ print("Collecting features of training set...")
107
+ features, labels = collect_features(
108
+ model=model,
109
+ loader=train_loader,
110
+ device=device
111
+ )
112
+
113
+ print("Collecting features of validation set...")
114
+ val_features, val_labels = collect_features(
115
+ model=model,
116
+ loader=val_loader,
117
+ device=device
118
+ )
119
+ # 可视化特征
120
+ # print("Visualizing features using t-SNE...")
121
+ # visualize_features(
122
+ # features=features,
123
+ # labels=labels,
124
+ # method='tsne',
125
+ # save_path='outputs/tsne_features.png'
126
+ # )
127
+
128
+ print("Visualizing features using UMAP...")
129
+ visualize_features(
130
+ features=features,
131
+ labels=labels,
132
+ method='umap',
133
+ save_path='outputs/resnet50_train_FeatureMap.png',
134
+ include_unknown=False
135
+ )
136
+
137
+ visualize_features(
138
+ features=val_features,
139
+ labels=val_labels,
140
+ method='umap',
141
+ save_path='outputs/resnet50_val_FeatureMap.png',
142
+ include_unknown=True
143
+ )