Facepalm0 commited on
Commit
0b3548a
·
verified ·
1 Parent(s): 433a583

Upload client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. client.py +155 -0
client.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import socketio
2
+ import time
3
+ import numpy as np
4
+ import os
5
+ import socketio.exceptions
6
+ from models.resnet import ImageClassifier
7
+ from search import search, adjust_grid
8
+
9
+ actions = []
10
+ def action_policy(grid=None, loc=None, rounds=0):
11
+ """
12
+ Args:
13
+ grid: 当前网格状态
14
+ loc: 当前位置
15
+ rounds: 当前回合数
16
+ Returns:
17
+ action: 下一步动作
18
+ """
19
+ # 第0回合时计算整个行动序列
20
+
21
+ return None
22
+
23
+ def recognition(img):
24
+ """
25
+ Args:
26
+ img: shape [600,600,3] 的list,RGB格式
27
+ Returns:
28
+ grid: (12,12) 的numpy数组
29
+ """
30
+ if not hasattr(recognition, 'classifier'):
31
+ recognition.classifier = ImageClassifier(model_type='resnet18', model_path='models/best_model_99.92_02.pth', openmax_path='models/best_openmax_95.62_02.pth', multiplier=0.6)
32
+
33
+ # 先转换为numpy数组
34
+ img = np.array(img, dtype=np.uint8)
35
+ # 将图像分割成网格
36
+ patches = []
37
+ tile_size = 50
38
+
39
+ for i in range(12):
40
+ for j in range(12):
41
+ patch = img[i*tile_size:(i+1)*tile_size, j*tile_size:(j+1)*tile_size]
42
+ patches.append(patch)
43
+
44
+ patches = np.array(patches)
45
+ # 获取预测结果
46
+ # predictions = recognition.classifier.resnet_predict(patches)
47
+ predictions, openmax_probs = recognition.classifier.predict(patches)
48
+ # 重塑为12x12网格
49
+ grid, probs = adjust_grid(predictions.cpu().numpy(), openmax_probs.cpu().numpy())
50
+
51
+ return grid, probs
52
+
53
+
54
+ def team_play_game(team_id, game_type, game_data_id, ip, port):
55
+ sio = socketio.Client(request_timeout=60)
56
+ grid = None
57
+ begin = game_type + game_data_id
58
+ @sio.event
59
+ def connect():
60
+ print(f"Connected to server, game_type: {game_type}, game data id: {begin}")
61
+ pass
62
+ @sio.event
63
+ def disconnect():
64
+ # print(f"End game {begin}, disconnected from server")
65
+ pass
66
+ @sio.event
67
+ def connect_error(data):
68
+ print('Connect error', data)
69
+ @sio.event
70
+ def disconnect_error(data):
71
+ print('Disconnect error', data)
72
+ @sio.event
73
+ def response(data):
74
+ nonlocal grid
75
+ if 'error' in data:
76
+ print(data['error'])
77
+ sio.disconnect()
78
+ else:
79
+ try:
80
+ if data['rounds'] == 0:
81
+ print(f"Team {data['team_id']} begin game {data['game_id']}")
82
+ is_end = data.get('is_end', False)
83
+ score = data['score']
84
+ bag = data['bag']
85
+ loc = data['loc']
86
+ game_id = data['game_id']
87
+ os.makedirs(f'./{data["team_id"]}/', exist_ok=True)
88
+ send_data = {'team_id': data['team_id'], 'game_id': game_id}
89
+ if data['rounds']==0:
90
+ if (game_type == 'a'):
91
+ grid, probs= np.array(data['grid'], dtype=int), np.ones((12, 12))
92
+ if (game_type == '2'):
93
+ grid, probs = recognition(data['img'])
94
+ print('Recognition Finished!')
95
+ send_data['grid_pred'] = grid.tolist()
96
+
97
+ # 使用 search 函数替代直接使用 Algorithm_Agent
98
+ actions[:] = search(grid=grid, probs=probs, loc=loc)[:]
99
+
100
+ score_npy = f'./{data["team_id"]}/{data["game_id"]}_score.npy'
101
+ if os.path.exists(score_npy):
102
+ prev_score = np.load(score_npy)
103
+ else:
104
+ prev_score = np.array(0.0)
105
+ np.save(score_npy, prev_score + score)
106
+ if is_end:
107
+ print(f"Team {data['team_id']} end game {data['game_id']}, cum_score: {prev_score + score:.2f}")
108
+ if game_type == '2':
109
+ print(f'Recognition acc on this game fig: {data["acc"]}')
110
+ print(f'time penalty:{data["time penalty"]}')
111
+ sio.disconnect()
112
+ else:
113
+ action = actions.pop(0)
114
+ if action == 4:
115
+ grid[loc[0], loc[1]] = -1
116
+ send_data['action'] = action
117
+ if sio.connected:
118
+ sio.emit('continue', send_data)
119
+ else:
120
+ print('sio not connected')
121
+ except Exception as e:
122
+ print(f'{e}')
123
+ sio.disconnect()
124
+ try:
125
+ # 连接到服务器
126
+ sio.connect(f'http://{ip}:{port}/', wait_timeout=30)
127
+ # 发送消息到服务器
128
+ message = {'team_id': team_id, 'begin': begin}
129
+ sio.emit('begin', message)
130
+ sio.wait()
131
+ except socketio.exceptions.ConnectionError as e:
132
+ print('Connection Error')
133
+ sio.disconnect()
134
+ except Exception as e:
135
+ print(f'Exception: {e}')
136
+ sio.disconnect()
137
+ finally:
138
+ # print('end team play game')
139
+ pass
140
+
141
+
142
+ if __name__ == '__main__':
143
+ team_id = f'ewv9ssdcuvg6'
144
+ ip = '69.230.243.237'
145
+ port = '8086'
146
+ # game_type must be in ['2', 'a'], '2' for full game and recognition only, 'a' for action_only
147
+ game_type = '2'
148
+
149
+ # 初赛的第1阶段,game_data_id must be in ['00000', '00001', ..., '00099']
150
+ # 初赛的终榜阶段,game_data_id must be in ['00000', '00001', ..., '00199']
151
+ game_data_id = [f'{i:05}' for i in range(2,100)]
152
+ st = time.time()
153
+ for gdi in game_data_id:
154
+ team_play_game(team_id, game_type, gdi, ip, port)
155
+ print(f'Total time: {(time.time()-st):.1f}s')