Facepalm0 commited on
Commit
16e53c1
·
verified ·
1 Parent(s): e95fed9

Upload search.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. search.py +247 -0
search.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from make_env import GridWorldEnv
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import itertools
5
+ import random
6
+ import concurrent.futures
7
+ # np.random.seed(0)
8
+
9
class Algorithm_Agent():
    """Heuristic planner that orders every collectable grid cell into a route.

    On construction the agent builds a visiting order over all cells of
    ``grid`` whose value is not ``-1`` (``arrange_points``) and converts that
    order into a flat list of discrete actions (``plan_action``).

    Attributes set by ``__init__``:
        path: list of ``[x, y]`` cells in visiting order.
        path_category: category of each cell in ``path``.
        actions: action codes that walk ``path`` starting from ``loc``.
        cumulated_reward: planner's own reward estimate for the final route
            (``100 + 36 - cost / 10``).
    """

    # Number of remove-and-reinsert refinement passes run by arrange_points.
    REFINE_ITERATIONS = 1000
    # Weight of the confidence-ordering penalty in calculate_length
    # (0.0 means the penalty is computed but has no influence).
    PROB_PENALTY_WEIGHT = 0.0

    def __init__(self, num_categories, grid_size, grid, probs, loc):
        """Store the problem and immediately plan the route and the actions.

        Args:
            num_categories: number of category ids; grid values other than -1
                must lie in ``range(num_categories)``.
            grid_size: ``(rows, cols)`` of ``grid``.
            grid: 2-D array of category ids, ``-1`` marking cells to skip.
            probs: 2-D array of per-cell confidence, indexed like ``grid``.
            loc: starting ``(x, y)`` position of the agent.
        """
        self.num_categories = num_categories
        self.grid_size = grid_size
        self.grid = grid
        self.probs = probs
        self.loc = loc
        self.current_loc = [loc[0], loc[1]]
        self.path, self.path_category = self.arrange_points()
        self.actions = self.plan_action()

    def calculate_length(self, paths, elim_paths, prob_paths):
        """Score N candidate routes at once; a lower score is better.

        Args:
            paths: array of shape (N, L, 2), cell coordinates in visiting order.
            elim_paths: array of shape (N, L) as produced by ``get_elim_path``.
            prob_paths: array of shape (N, L), confidence of each visited cell.

        Returns:
            Array of shape (N,) with the cost of every candidate route.
        """
        paths = np.asarray(paths)
        elim_paths = np.asarray(elim_paths)
        prob_paths = np.asarray(prob_paths)

        # Manhattan distance between consecutive stops, plus 1 per hop
        # (presumably the pick-up action -- TODO confirm against the env).
        lengths = np.sum(np.abs(paths[:, :-1] - paths[:, 1:]), axis=-1) + 1  # (N, L-1)
        start_to_first = np.sum(np.abs(np.asarray(self.loc) - paths[:, 0]), axis=-1)
        motion_length = np.sum(lengths, axis=-1) + start_to_first + 1  # (N,)

        # Suffix sums: entry j is the distance still to travel from stop j on.
        # NOTE(review): 14.4 and the factor 4 below are tuning constants whose
        # origin is not visible in this file.
        cum_lengths = np.cumsum(lengths[:, ::-1], axis=-1)[:, ::-1] / 14.4  # (N, L-1)
        # The last column of elim_paths takes no part in the load term.
        load_length = (
            np.sum(cum_lengths, axis=-1)
            - 4 * np.sum(cum_lengths * elim_paths[:, :-1], axis=-1)
        )

        # Position-weighted confidence, intended to make routes more robust to
        # low-confidence cells. Summed per candidate so it can actually rank
        # candidates once PROB_PENALTY_WEIGHT is non-zero.
        positions = np.arange(prob_paths.shape[-1])
        prob_length = np.sum(positions * prob_paths, axis=-1)

        return motion_length + load_length + self.PROB_PENALTY_WEIGHT * prob_length

    def get_elim_path(self, category_paths):
        """Mark the stops that complete a group of four same-category cells.

        Args:
            category_paths: array of shape (N, L) with the category of each stop.

        Returns:
            Array shaped and typed like ``category_paths``; entry ``[n, i]`` is
            1 when stop ``i`` is the 4th, 8th, ... occurrence of its category
            in route ``n``, otherwise 0.
        """
        category_paths = np.asarray(category_paths)
        elim_path = np.zeros_like(category_paths)
        for i in range(1, category_paths.shape[1]):
            # How many earlier stops of this route share stop i's category.
            earlier_same = np.sum(
                category_paths[:, :i] == category_paths[:, i:i + 1], axis=-1
            )
            elim_path[:, i] = (earlier_same + 1) % 4 == 0
        return elim_path

    def find_shortest_path(self, points):
        """Brute-force the shortest Manhattan ordering of a few points.

        Every permutation is tried, so this is only meant for tiny inputs
        (``arrange_points`` passes at most four points).

        Returns:
            ``(ordered_points, total_length)``; the first shortest permutation
            found wins ties.
        """
        min_path = None
        min_length = float('inf')
        for perm in itertools.permutations(points):
            length = sum(
                abs(a[0] - b[0]) + abs(a[1] - b[1])
                for a, b in zip(perm, perm[1:])
            )
            if length < min_length:
                min_length = length
                min_path = list(perm)
        return min_path, min_length

    def insert_point(self, path, category_path, prob_path, point, category, prob):
        """Find the cheapest slot at which to insert ``point`` into a route.

        All ``len(path) + 1`` insertion positions are materialised as candidate
        routes and scored in a single ``calculate_length`` call. The input
        lists are not modified.

        Returns:
            ``(best_position, cost_of_route_after_insertion)``.
        """
        # reshape keeps an empty route 2-D so np.insert yields (1, 2), not (2,).
        base_path = np.asarray(path, dtype=float).reshape(-1, 2)
        n_slots = len(base_path) + 1

        new_path = np.zeros((n_slots, n_slots, 2))
        new_category_path = np.zeros((n_slots, n_slots))
        new_prob_path = np.zeros((n_slots, n_slots))
        for slot in range(n_slots):
            new_path[slot] = np.insert(base_path, slot, point, axis=0)
            new_category_path[slot] = np.insert(category_path, slot, category)
            new_prob_path[slot] = np.insert(prob_path, slot, prob)

        new_elim_path = self.get_elim_path(new_category_path)
        lengths = self.calculate_length(new_path, new_elim_path, new_prob_path)
        best_position = int(np.argmin(lengths))
        return best_position, float(lengths[best_position])

    def _route_cost(self, path, category_path, prob_path):
        """Return the calculate_length score of one complete route (0.0 if empty)."""
        if not path:
            return 0.0
        paths = np.asarray(path, dtype=float)[np.newaxis]
        categories = np.asarray(category_path, dtype=float)[np.newaxis]
        probs = np.asarray(prob_path, dtype=float)[np.newaxis]
        scores = self.calculate_length(paths, self.get_elim_path(categories), probs)
        return float(scores[0])

    def arrange_points(self):
        """Build the visiting order: greedy construction, then local refinement.

        1. Cells are grouped by category; categories are processed in a random
           order (so repeated runs give different routes).
        2. Each category is consumed in chunks of up to four cells. The very
           first chunk is ordered by brute force; every later cell is placed
           at its cheapest insertion position.
        3. ``REFINE_ITERATIONS`` times, a random stop is removed and re-inserted
           at its cheapest position.

        Also sets ``self.cumulated_reward`` from the cost of the final route.

        Returns:
            ``(path, category_path)`` as parallel lists.
        """
        category_order = random.sample(range(self.num_categories), self.num_categories)
        points_by_category = {category: [] for category in category_order}
        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                category = self.grid[x, y]
                if category != -1:
                    points_by_category[category].append([x, y])

        path, category_path, prob_path = [], [], []
        for category, points in points_by_category.items():
            for start in range(0, len(points), 4):
                subset = points[start:start + 4]
                if not path:
                    path, _ = self.find_shortest_path(subset)
                    category_path = [category] * len(path)
                    prob_path = [self.probs[p[0], p[1]] for p in path]
                    continue
                for point in subset:
                    prob = self.probs[point[0], point[1]]
                    position, _ = self.insert_point(
                        path, category_path, prob_path, point, category, prob
                    )
                    path.insert(position, point)
                    category_path.insert(position, category)
                    prob_path.insert(position, prob)

        # Cost of the constructed route; overwritten by each refinement pass.
        length = self._route_cost(path, category_path, prob_path)
        if path:
            for _ in range(self.REFINE_ITERATIONS):
                # Draw from the actual route length: the grid may hold fewer
                # collectable cells than rows * cols when it contains -1.
                index = np.random.randint(0, len(path))
                point = path.pop(index)
                category = category_path.pop(index)
                prob = prob_path.pop(index)
                position, length = self.insert_point(
                    path, category_path, prob_path, point, category, prob
                )
                path.insert(position, point)
                category_path.insert(position, category)
                prob_path.insert(position, prob)

        # NOTE(review): 100 + 36 is the reward offset used by the original
        # code; its derivation is not visible in this file.
        self.cumulated_reward = 100 + 36 - length / 10
        return path, category_path

    def plan_action(self):
        """Translate ``self.path`` into action codes, walking from ``current_loc``.

        Codes: 0 = x + 1, 1 = y + 1, 2 = x - 1, 3 = y - 1; a 4 is appended each
        time a target cell is reached (presumably the "collect" action -- TODO
        confirm against the environment). ``self.current_loc`` is updated as the
        walk proceeds and ends on the last cell of the path.
        """
        actions = []
        for target in self.path:
            target_x, target_y = target[0], target[1]
            while True:
                cur_x, cur_y = self.current_loc
                if cur_x == target_x and cur_y == target_y:
                    break
                # Priority order matters: it fixes which axis is walked first.
                if cur_x < target_x:
                    actions.append(0)
                    self.current_loc = [cur_x + 1, cur_y]
                elif cur_y < target_y:
                    actions.append(1)
                    self.current_loc = [cur_x, cur_y + 1]
                elif cur_x > target_x:
                    actions.append(2)
                    self.current_loc = [cur_x - 1, cur_y]
                else:
                    actions.append(3)
                    self.current_loc = [cur_x, cur_y - 1]
            actions.append(4)
        return actions
138
+
139
def adjust_grid(predictions, openmax_probs):
    """Relabel samples so that class counts move towards multiples of four.

    Three passes run over categories 0..19, each keyed on the class count
    modulo 4 at the time the category is visited:

    * remainder 1: the least confident member of the class is moved to its
      next most probable class;
    * remainder 2: the two outside samples most probable for the class are
      pulled into it;
    * remainder 3: the single outside sample most probable for the class is
      pulled into it.

    Later moves can disturb classes handled earlier, so the result is a
    heuristic, not a guarantee that every count is a multiple of four.

    Args:
        predictions: 1-D integer array of 144 class ids. It is updated in
            place (the returned grid is a reshaped view of it).
        openmax_probs: array of shape (144, >=21) with per-class scores.
            It is NOT modified.

    Returns:
        ``(grid, probs)``, both of shape (12, 12): the adjusted class ids and
        the score of each sample for its final class.
    """
    scores = np.asarray(openmax_probs)
    # Selection works on a private copy: the -1 sentinel written below must
    # not leak into the caller's matrix nor into the returned confidences.
    masked = scores.copy()

    class_counts = np.bincount(predictions, minlength=21)

    def pull_best_into(category):
        """Move the outside sample that scores highest for `category` into it."""
        other_indices = np.where(predictions != category)[0]
        if len(other_indices) == 0:
            return
        best_idx = other_indices[np.argmax(masked[other_indices, category])]
        class_counts[predictions[best_idx]] -= 1
        class_counts[category] += 1
        predictions[best_idx] = category

    # Remainder 1: push the weakest member out to its runner-up class.
    for category in range(20):
        if class_counts[category] % 4 != 1:
            continue
        category_indices = np.where(predictions == category)[0]
        if len(category_indices) == 0:
            continue
        worst_idx = category_indices[np.argmin(masked[category_indices, category])]
        # Exclude the current class; this also keeps later passes from
        # pulling the sample straight back.
        masked[worst_idx, category] = -1
        new_category = np.argmax(masked[worst_idx])
        class_counts[category] -= 1
        class_counts[new_category] += 1
        predictions[worst_idx] = new_category

    # Remainder 2: two more samples complete the group.
    for category in range(20):
        if class_counts[category] % 4 == 2:
            for _ in range(2):
                pull_best_into(category)

    # Remainder 3: one more sample completes the group.
    for category in range(20):
        if class_counts[category] % 4 == 3:
            pull_best_into(category)

    probs = scores[np.arange(144), predictions]

    return predictions.reshape(12, 12), probs.reshape(12, 12)
207
+
208
def search_once(grid, probs, loc):
    """Run one randomized planning attempt on a 12x12 grid with 21 categories.

    Returns:
        ``(actions, cumulated_reward)`` taken from the freshly built agent.
    """
    planner = Algorithm_Agent(
        num_categories=21,
        grid_size=(12, 12),
        grid=grid,
        probs=probs,
        loc=loc,
    )
    return planner.actions, planner.cumulated_reward
211
+
212
# Run `num_iterations` independent search_once attempts in worker processes
# and keep the best one.
def _reseed_worker():
    """Give each worker process its own NumPy random stream."""
    # With the fork start method every child inherits the parent's global
    # NumPy RNG state; without reseeding, all workers would draw the same
    # np.random.randint sequence and the restarts would not be independent.
    np.random.seed()


def search(grid, probs, loc, num_iterations=60):
    """Plan with several randomized restarts in parallel and return the best.

    Each restart calls ``search_once`` on copies of the inputs. The attempt
    with the highest planner-estimated reward is selected, replayed in a
    fresh ``GridWorldEnv`` whose grid and location are overwritten with the
    given ones, and the reward actually obtained is printed.

    Args:
        grid: 2-D array of category ids.
        probs: 2-D array of per-cell confidence, shaped like ``grid``.
        loc: starting position of the agent (array-like with ``.copy()``).
        num_iterations: number of independent restarts (must be >= 1).

    Returns:
        The action list of the best restart.

    Raises:
        ValueError: if ``num_iterations`` is smaller than 1.
    """
    import os  # local import: `os` is not among this module's top-level imports

    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")

    # The work is CPU-bound, so more processes than cores only adds overhead.
    max_workers = min(num_iterations, os.cpu_count() or 1)
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=max_workers, initializer=_reseed_worker
    ) as executor:
        futures = [
            executor.submit(search_once, grid.copy(), probs.copy(), loc.copy())
            for _ in range(num_iterations)
        ]
        results = [
            future.result()
            for future in concurrent.futures.as_completed(futures)
        ]

    # Keep the attempt with the highest estimated reward.
    optim_actions, _ = max(results, key=lambda result: result[1])

    # Replay the chosen actions in the environment to report the real reward.
    env = GridWorldEnv()
    cumulated_reward = 0
    env.reset()
    env.grid, env.loc = grid.copy(), loc.copy()
    for action in optim_actions:
        _, reward, _, _, _ = env.step(action)
        cumulated_reward += reward
    print(f'Final reward: {cumulated_reward}')

    return optim_actions
235
+
236
+
237
if __name__ == "__main__":
    # Demo: plan on a grid in which five random cells carry shuffled labels.
    for _ in range(1):
        test_env = GridWorldEnv()
        test_env.reset()
        grid, loc = test_env.grid.copy(), test_env.loc.copy()
        pred_grid, pred_loc = test_env.grid.copy(), test_env.loc.copy()

        # Simulate recognition errors: pick five cells (they may coincide) and
        # permute their labels. random.sample draws two distinct numbers, so
        # the chosen cells are never on the main diagonal.
        cells = [random.sample(range(12), 2) for _ in range(5)]
        a, b, c, d, e = (pred_grid[row, col] for row, col in cells)
        for (row, col), label in zip(cells, (b, e, a, c, d)):
            pred_grid[row, col] = label

        # No classifier confidences exist in this demo; uniform values are
        # enough because the planner weights the confidence term by 0.0.
        probs = np.ones(pred_grid.shape)

        # Search on the 5-cell-confused grid. The previous call,
        # search(grid, loc, grid, loc), did not match the signature
        # search(grid, probs, loc, num_iterations).
        # NOTE(review): search() replays the plan on the grid it is given, so
        # the printed reward refers to pred_grid; pass `grid` and `loc` instead
        # if the reward on the true grid is wanted.
        search(pred_grid, probs, pred_loc)
247
+