# client.py — uploaded by Facepalm0 with huggingface_hub
# ("Upload client.py with huggingface_hub", commit 0b3548a, verified; 5.72 kB).
import socketio
import time
import numpy as np
import os
import socketio.exceptions
from models.resnet import ImageClassifier
from search import search, adjust_grid
# Module-level action plan for the current game: replaced in place on round 0
# by team_play_game (actions[:] = search(...)) and consumed one action per
# round with actions.pop(0).
actions = []
def action_policy(grid=None, loc=None, rounds=0):
    """Policy hook for choosing the next action -- currently an unimplemented stub.

    The intended design (from the original comment) is to compute the whole
    action sequence on round 0 and then hand out one action per round.

    Args:
        grid: current grid state.
        loc: current position.
        rounds: current round number.

    Returns:
        Always ``None`` for now; no action is computed.
    """
    next_action = None
    return next_action
def recognition(img):
    """Classify every tile of the board image and return the predicted grid.

    The image is cut into a 12x12 lattice of 50x50-pixel tiles in row-major
    order (tile index ``row * 12 + col``). All 144 tiles are classified in one
    batch, and the raw predictions are post-processed by ``adjust_grid``.

    Args:
        img: the board image as a nested list or array, expected shape
            [600, 600, 3] in RGB format. Pixels beyond the first 600 rows and
            columns are ignored.

    Returns:
        The ``(grid, probs)`` tuple produced by ``adjust_grid``. ``grid`` is
        expected to be a (12, 12) numpy array. The exact meaning of ``probs``
        is defined by ``adjust_grid``.

    Raises:
        ValueError: if the image is smaller than 600x600 pixels.
    """
    grid_size = 12
    tile_size = 50
    board_px = grid_size * tile_size

    # Build the classifier once and cache it on the function object, so later
    # calls (one per game) reuse the already-loaded weights.
    if not hasattr(recognition, 'classifier'):
        recognition.classifier = ImageClassifier(
            model_type='resnet18',
            model_path='models/best_model_99.92_02.pth',
            openmax_path='models/best_openmax_95.62_02.pth',
            multiplier=0.6,
        )

    img = np.asarray(img, dtype=np.uint8)
    if img.ndim < 2 or img.shape[0] < board_px or img.shape[1] < board_px:
        raise ValueError(
            f'expected an image of at least {board_px}x{board_px} pixels, '
            f'got shape {img.shape}'
        )

    # Crop to the board and split it into tiles without a Python double loop:
    # (H, W, C) -> (12, 50, 12, 50, C) -> (12, 12, 50, 50, C) -> (144, 50, 50, C).
    # The tile order is row-major, the same as iterating rows then columns.
    board = img[:board_px, :board_px]
    channels = board.shape[2:]
    patches = (
        board.reshape((grid_size, tile_size, grid_size, tile_size) + channels)
        .swapaxes(1, 2)
        .reshape((grid_size * grid_size, tile_size, tile_size) + channels)
    )

    predictions, openmax_probs = recognition.classifier.predict(patches)
    # adjust_grid reshapes the flat per-tile predictions into the 12x12 grid.
    grid, probs = adjust_grid(predictions.cpu().numpy(), openmax_probs.cpu().numpy())
    return grid, probs
def team_play_game(team_id, game_type, game_data_id, ip, port):
    """Play one game against the competition server over Socket.IO.

    The client connects to ``http://{ip}:{port}/`` and emits ``begin`` for game
    ``game_type + game_data_id``. It then answers each ``response`` event with
    the next planned action until the server reports ``is_end``.

    On round 0 the board is obtained in one of two ways:
    - for ``game_type == 'a'`` it is taken from ``data['grid']``;
    - for ``game_type == '2'`` it is recognised from ``data['img']`` and sent
      back as ``grid_pred``.
    The full plan is then computed with ``search`` and stored in the
    module-level ``actions`` list. Every later round pops one action from it.

    The score is accumulated in ``./{team_id}/{game_id}_score.npy``.
    NOTE(review): that file persists across runs, so replaying the same
    game_id keeps adding to the old total. Confirm this is intended.

    Args:
        team_id: team identifier sent with ``begin``.
        game_type: '2' (full game including recognition) or 'a' (action only).
        game_data_id: zero-padded game data id, e.g. '00042'.
        ip: server host.
        port: server port.

    Errors while connecting or during play are printed and end the game by
    disconnecting. They are not re-raised, so a batch of games keeps running.
    """
    import traceback

    sio = socketio.Client(request_timeout=60)
    grid = None  # board of the current game, shared with the response handler
    begin = game_type + game_data_id

    @sio.event
    def connect():
        print(f"Connected to server, game_type: {game_type}, game data id: {begin}")

    @sio.event
    def disconnect():
        # Nothing to clean up; the game loop ends when sio.wait() returns.
        pass

    @sio.event
    def connect_error(data):
        print('Connect error', data)

    @sio.event
    def disconnect_error(data):
        # NOTE(review): this is not a built-in python-socketio client event.
        # It only fires if the server emits 'disconnect_error'.
        print('Disconnect error', data)

    @sio.event
    def response(data):
        nonlocal grid
        if 'error' in data:
            print(data['error'])
            sio.disconnect()
            return
        try:
            rounds = data['rounds']
            team = data['team_id']
            game_id = data['game_id']
            if rounds == 0:
                print(f"Team {team} begin game {game_id}")
            is_end = data.get('is_end', False)
            score = data['score']
            loc = data['loc']
            os.makedirs(f'./{team}/', exist_ok=True)
            send_data = {'team_id': team, 'game_id': game_id}

            if rounds == 0:
                if game_type == 'a':
                    grid, probs = np.array(data['grid'], dtype=int), np.ones((12, 12))
                elif game_type == '2':
                    grid, probs = recognition(data['img'])
                    print('Recognition Finished!')
                    send_data['grid_pred'] = grid.tolist()
                else:
                    raise ValueError(
                        f"unsupported game_type {game_type!r}; expected '2' or 'a'"
                    )
                # Plan the whole game once. Later rounds replay this plan.
                actions[:] = search(grid=grid, probs=probs, loc=loc)

            score_npy = f'./{team}/{game_id}_score.npy'
            if os.path.exists(score_npy):
                prev_score = np.load(score_npy)
            else:
                prev_score = np.array(0.0)
            cum_score = prev_score + score
            np.save(score_npy, cum_score)

            if is_end:
                print(f"Team {team} end game {game_id}, cum_score: {cum_score:.2f}")
                if game_type == '2':
                    print(f'Recognition acc on this game fig: {data["acc"]}')
                print(f'time penalty:{data.get("time penalty")}')
                sio.disconnect()
                return

            if not actions:
                raise RuntimeError('action plan exhausted before the game ended')
            action = actions.pop(0)
            if action == 4:
                # Mirror action 4 in the local board copy by marking the current
                # cell as -1. Presumably its item was collected; confirm this
                # against the game rules.
                grid[loc[0], loc[1]] = -1
            send_data['action'] = action
            if sio.connected:
                sio.emit('continue', send_data)
            else:
                print('sio not connected')
        except Exception:
            # Print the full traceback, not just str(e), so failures are debuggable.
            traceback.print_exc()
            sio.disconnect()

    try:
        sio.connect(f'http://{ip}:{port}/', wait_timeout=30)
        sio.emit('begin', {'team_id': team_id, 'begin': begin})
        sio.wait()
    except socketio.exceptions.ConnectionError as e:
        print(f'Connection Error: {e}')
        sio.disconnect()
    except Exception as e:
        print(f'Exception: {e}')
        sio.disconnect()
if __name__ == '__main__':
    team_id = 'ewv9ssdcuvg6'
    ip = '69.230.243.237'
    port = '8086'
    # game_type must be '2' or 'a':
    # '2' is the full game (including recognition), 'a' is action only.
    game_type = '2'
    # Preliminary round, stage 1: game_data_id must be in '00000' ... '00099'.
    # Preliminary round, final leaderboard: game_data_id must be in '00000' ... '00199'.
    # NOTE(review): starts at '00002', skipping '00000' and '00001'. Confirm this is intentional.
    game_data_ids = [f'{i:05}' for i in range(2, 100)]
    # perf_counter is monotonic, so wall-clock adjustments cannot distort the timing.
    start = time.perf_counter()
    for gdi in game_data_ids:
        team_play_game(team_id, game_type, gdi, ip, port)
    print(f'Total time: {(time.perf_counter() - start):.1f}s')