File size: 5,717 Bytes
0b3548a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import socketio
import time
import numpy as np
import os
import socketio.exceptions
from models.resnet import ImageClassifier
from search import search, adjust_grid
# Module-level queue of precomputed actions. team_play_game() refills it in
# place on round 0 (from search()) and pops one action per server step.
actions: list = []
def action_policy(grid=None, loc=None, rounds=0):
    """Placeholder per-step policy; always returns ``None``.

    Args:
        grid: Current grid state.
        loc: Current position.
        rounds: Current round number.

    Returns:
        The next action -- currently always ``None``, because no policy is
        implemented here. In this file the game loop instead precomputes
        the whole action sequence with ``search()`` on round 0.
    """
def recognition(img):
    """Classify every tile of the board image and build the grid.

    The image is cut into a 12x12 grid of 50x50-pixel tiles in row-major
    order (patch ``k`` is board cell ``(k // 12, k % 12)``). The tiles are
    classified by an ``ImageClassifier``, which is created on the first
    call and cached on the function object. The raw predictions and
    OpenMax probabilities are then post-processed by ``adjust_grid``.

    Args:
        img: Board image of shape [600, 600, 3], RGB order, given as a
            nested list (anything ``np.array`` accepts works).

    Returns:
        tuple: ``(grid, probs)`` as returned by ``adjust_grid``. ``grid`` is
        the (12, 12) numpy array of tile classes. ``probs`` is presumably a
        per-tile (12, 12) confidence array -- TODO confirm against
        ``search.adjust_grid``.

    Raises:
        ValueError: If the image is smaller than 600x600 pixels.
    """
    # Load the model weights only once per process.
    if not hasattr(recognition, 'classifier'):
        recognition.classifier = ImageClassifier(
            model_type='resnet18',
            model_path='models/best_model_99.92_02.pth',
            openmax_path='models/best_openmax_95.62_02.pth',
            multiplier=0.6,
        )

    img = np.array(img, dtype=np.uint8)

    grid_size = 12
    tile_size = 50
    needed = grid_size * tile_size
    # Too-small input would otherwise fail deep inside numpy with a ragged-array error.
    if img.ndim < 2 or img.shape[0] < needed or img.shape[1] < needed:
        raise ValueError(
            f'recognition expects an image of at least {needed}x{needed} pixels, '
            f'got shape {img.shape}'
        )

    # Row-major tiling: outer loop over rows, inner loop over columns.
    patches = np.array([
        img[r * tile_size:(r + 1) * tile_size, c * tile_size:(c + 1) * tile_size]
        for r in range(grid_size)
        for c in range(grid_size)
    ])

    predictions, openmax_probs = recognition.classifier.predict(patches)
    # predict() returns tensors (hence .cpu()); adjust_grid works on numpy arrays.
    grid, probs = adjust_grid(predictions.cpu().numpy(), openmax_probs.cpu().numpy())
    return grid, probs
def _accumulate_score(team_id, game_id, score):
    """Add ``score`` to the running total saved in ``./<team_id>/<game_id>_score.npy``.

    The directory is created if needed. A missing file counts as a previous
    total of 0.0. Returns the new cumulative total, a numpy value that
    supports ``:.2f`` formatting.
    """
    os.makedirs(f'./{team_id}/', exist_ok=True)
    score_npy = f'./{team_id}/{game_id}_score.npy'
    if os.path.exists(score_npy):
        prev_score = np.load(score_npy)
    else:
        prev_score = np.array(0.0)
    total = prev_score + score
    np.save(score_npy, total)
    return total


def team_play_game(team_id, game_type, game_data_id, ip, port):
    """Play one game against the server over Socket.IO.

    Connects to ``http://{ip}:{port}/`` and emits a ``'begin'`` event. It
    then answers each ``'response'`` event with a ``'continue'`` event
    carrying the next action, until the server reports ``is_end`` or an
    error. At that point the client disconnects.

    On round 0 the board is obtained: from ``data['grid']`` for game type
    ``'a'``, or by running ``recognition`` on ``data['img']`` for game type
    ``'2'``. ``search`` then precomputes the full action sequence into the
    module-level ``actions`` list. Every non-final response pops the next
    action from that list.

    Args:
        team_id: Team identifier sent in the ``'begin'`` message.
        game_type: ``'2'`` (recognition + actions) or ``'a'`` (actions only).
        game_data_id: Zero-padded game data id such as ``'00042'``. It is
            appended to ``game_type`` to form the ``begin`` id.
        ip: Server host.
        port: Server port.

    Errors raised while handling a server message, and connection errors,
    are printed and end the game; they are not propagated to the caller.
    """
    sio = socketio.Client(request_timeout=60)
    grid = None
    begin = game_type + game_data_id

    @sio.event
    def connect():
        print(f"Connected to server, game_type: {game_type}, game data id: {begin}")

    @sio.event
    def disconnect():
        pass

    @sio.event
    def connect_error(data):
        print('Connect error', data)

    @sio.event
    def disconnect_error(data):
        print('Disconnect error', data)

    @sio.event
    def response(data):
        nonlocal grid
        if 'error' in data:
            print(data['error'])
            sio.disconnect()
            return
        try:
            rounds = data['rounds']
            if rounds == 0:
                print(f"Team {data['team_id']} begin game {data['game_id']}")
            is_end = data.get('is_end', False)
            score = data['score']
            loc = data['loc']
            game_id = data['game_id']
            send_data = {'team_id': data['team_id'], 'game_id': game_id}

            if rounds == 0:
                if game_type == 'a':
                    grid, probs = np.array(data['grid'], dtype=int), np.ones((12, 12))
                elif game_type == '2':
                    grid, probs = recognition(data['img'])
                    print('Recognition Finished!')
                else:
                    # Previously this fell through with probs unbound and grid None.
                    raise ValueError(f"game_type must be '2' or 'a', got {game_type!r}")
                send_data['grid_pred'] = grid.tolist()
                # Refill the shared list in place; later rounds consume it.
                actions[:] = search(grid=grid, probs=probs, loc=loc)

            cum_score = _accumulate_score(data['team_id'], game_id, score)

            if is_end:
                print(f"Team {data['team_id']} end game {data['game_id']}, cum_score: {cum_score:.2f}")
                if game_type == '2':
                    print(f'Recognition acc on this game fig: {data["acc"]}')
                print(f'time penalty:{data["time penalty"]}')
                sio.disconnect()
                return

            if not actions:
                raise RuntimeError('no precomputed actions left; search() returned too few steps')
            action = actions.pop(0)
            if action == 4:
                # Keep the local board in sync: this cell is marked -1 after action 4.
                grid[loc[0], loc[1]] = -1
            send_data['action'] = action
            if sio.connected:
                sio.emit('continue', send_data)
            else:
                print('sio not connected')
        except Exception as e:
            # Top-level boundary of the event handler: report and end this game.
            print(f'{type(e).__name__}: {e}')
            sio.disconnect()

    try:
        sio.connect(f'http://{ip}:{port}/', wait_timeout=30)
        sio.emit('begin', {'team_id': team_id, 'begin': begin})
        # Block until the server or a handler disconnects.
        sio.wait()
    except socketio.exceptions.ConnectionError:
        print('Connection Error')
        sio.disconnect()
    except Exception as e:
        print(f'Exception: {e}')
        sio.disconnect()
if __name__ == '__main__':
    # Team credentials and server address.
    team_id = 'ewv9ssdcuvg6'
    ip = '69.230.243.237'
    port = '8086'

    # game_type must be one of ['2', 'a']:
    #   '2' -> full game (tile recognition + actions)
    #   'a' -> actions only (the server sends the true grid)
    game_type = '2'

    # Preliminary round, stage 1: game_data_id in '00000' .. '00099'.
    # Preliminary round, final-leaderboard stage: '00000' .. '00199'.
    # NOTE(review): the range starts at 2 -- presumably resuming after ids
    # 0-1; adjust as needed.
    game_data_id = [str(idx).zfill(5) for idx in range(2, 100)]

    start_time = time.time()
    for data_id in game_data_id:
        team_play_game(team_id, game_type, data_id, ip, port)
    elapsed = time.time() - start_time
    print(f'Total time: {elapsed:.1f}s')
|