File size: 5,717 Bytes
0b3548a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import socketio
import time
import numpy as np
import os
import socketio.exceptions
from models.resnet import ImageClassifier
from search import search, adjust_grid

# Module-level action plan shared with team_play_game: it is overwritten in
# place at round 0 and one action is popped from its front on every later
# server message.
actions = []


def action_policy(grid=None, loc=None, rounds=0):
    """Return the next action for the agent (placeholder).

    Args:
        grid: Current grid state.
        loc: Current position.
        rounds: Current round number.

    Returns:
        The next action. This stub always returns ``None``; the game loop
        currently plans the entire action sequence at round 0 instead of
        asking this function each round.
    """
    # Intended design: compute the whole action sequence at round 0.
    next_action = None
    return next_action

def recognition(img):
    """Classify every tile of a 12x12 board image.

    The image is cut into 144 non-overlapping 50x50 tiles in row-major order
    (tile ``k = i * 12 + j`` covers rows ``i*50:(i+1)*50`` and columns
    ``j*50:(j+1)*50``). The tiles are classified in one batch, and the raw
    predictions are post-processed by ``adjust_grid``.

    The classifier is built lazily on the first call and cached as an
    attribute of this function, so the model files are loaded only once per
    process.

    Args:
        img: RGB image as a nested list (or array) of shape [600, 600, 3].
            If the image is larger, only its top-left 600x600 region is used.

    Returns:
        tuple: ``(grid, probs)`` as produced by ``adjust_grid``. ``grid`` is
        the (12, 12) numpy array of tile classes. ``probs`` presumably holds
        the matching per-tile confidences (TODO confirm against
        ``adjust_grid``).

    Raises:
        ValueError: If the image is smaller than 600x600 pixels.
    """
    if not hasattr(recognition, 'classifier'):
        recognition.classifier = ImageClassifier(
            model_type='resnet18',
            model_path='models/best_model_99.92_02.pth',
            openmax_path='models/best_openmax_95.62_02.pth',
            multiplier=0.6,
        )

    grid_size = 12
    tile_size = 50
    span = grid_size * tile_size

    img = np.array(img, dtype=np.uint8)
    # Fail with a clear message instead of a ragged-array error further down.
    if img.ndim < 2 or img.shape[0] < span or img.shape[1] < span:
        raise ValueError(
            f'recognition expects an image of at least {span}x{span} pixels, '
            f'got shape {img.shape}'
        )

    # Vectorized tiling, equivalent to the row-major double loop:
    # (600, 600, C) -> (12, 50, 12, 50, C) -> (12, 12, 50, 50, C) -> (144, 50, 50, C)
    channels = img.shape[2:]
    patches = (
        img[:span, :span]
        .reshape(grid_size, tile_size, grid_size, tile_size, *channels)
        .swapaxes(1, 2)
        .reshape(grid_size * grid_size, tile_size, tile_size, *channels)
    )

    predictions, openmax_probs = recognition.classifier.predict(patches)
    grid, probs = adjust_grid(predictions.cpu().numpy(), openmax_probs.cpu().numpy())
    return grid, probs


def team_play_game(team_id, game_type, game_data_id, ip, port):
    """Play one game against the competition server over Socket.IO.

    Connects to ``http://{ip}:{port}/``, emits a ``'begin'`` event and then
    waits while the ``response`` handler answers each server message with a
    ``'continue'`` event carrying the next planned action. The handlers
    disconnect when the server reports ``is_end``, sends an ``'error'``, or
    processing a message fails.

    At round 0 the whole action sequence is planned with ``search`` and
    stored in the module-level ``actions`` list; later rounds pop one action
    per message. The running score of each game is persisted to
    ``./{team_id}/{game_id}_score.npy``, using the ``team_id`` and ``game_id``
    echoed back by the server.

    Args:
        team_id: Team identifier sent to the server.
        game_type: ``'2'`` for the full game (the grid is recognised from
            ``data['img']``) or ``'a'`` for action-only (the server sends
            ``data['grid']``).
        game_data_id: Zero-padded game data id such as ``'00002'``; it is
            appended to ``game_type`` to form the ``begin`` id.
        ip: Server host.
        port: Server port.
    """
    sio = socketio.Client(request_timeout=60)
    grid = None
    begin = game_type + game_data_id

    @sio.event
    def connect():
        print(f"Connected to server, game_type: {game_type}, game data id: {begin}")

    @sio.event
    def disconnect():
        # Nothing to clean up; the per-game summary is printed on is_end.
        pass

    @sio.event
    def connect_error(data):
        print('Connect error', data)

    @sio.event
    def disconnect_error(data):
        print('Disconnect error', data)

    @sio.event
    def response(data):
        nonlocal grid
        if 'error' in data:
            print(data['error'])
            sio.disconnect()
            return
        try:
            rounds = data['rounds']
            if rounds == 0:
                print(f"Team {data['team_id']} begin game {data['game_id']}")
            is_end = data.get('is_end', False)
            score = data['score']
            loc = data['loc']
            game_id = data['game_id']
            os.makedirs(f'./{data["team_id"]}/', exist_ok=True)
            send_data = {'team_id': data['team_id'], 'game_id': game_id}

            if rounds == 0:
                if game_type == 'a':
                    # Action-only game: the server provides the grid, so every
                    # tile gets probability 1.
                    grid, probs = np.array(data['grid'], dtype=int), np.ones((12, 12))
                elif game_type == '2':
                    grid, probs = recognition(data['img'])
                    print('Recognition Finished!')
                    send_data['grid_pred'] = grid.tolist()
                else:
                    # Without this, an unknown type surfaced as an
                    # UnboundLocalError on `probs` below.
                    raise ValueError(
                        f"unsupported game_type {game_type!r}, expected '2' or 'a'"
                    )
                # Plan the whole game once with `search` (used instead of the
                # former Algorithm_Agent); later rounds pop from this plan.
                actions[:] = search(grid=grid, probs=probs, loc=loc)

            score_npy = f'./{data["team_id"]}/{data["game_id"]}_score.npy'
            # Round 0 starts a fresh total, so a file left over from an
            # earlier run of the same game_id is not added to this game.
            if rounds != 0 and os.path.exists(score_npy):
                prev_score = np.load(score_npy)
            else:
                prev_score = np.array(0.0)
            cum_score = prev_score + score
            np.save(score_npy, cum_score)

            if is_end:
                print(f"Team {data['team_id']} end game {data['game_id']}, cum_score: {cum_score:.2f}")
                if game_type == '2':
                    print(f'Recognition acc on this game fig: {data["acc"]}')
                print(f'time penalty:{data["time penalty"]}')
                sio.disconnect()
                return

            if not actions:
                raise RuntimeError(
                    'planned action sequence is exhausted but the game has not ended'
                )
            action = actions.pop(0)
            if action == 4:
                # Mark the current cell as -1 on the local copy of the grid.
                grid[loc[0], loc[1]] = -1
            send_data['action'] = action
            if sio.connected:
                sio.emit('continue', send_data)
            else:
                print('sio not connected')
        except Exception as e:
            # Best effort: report what failed (including the exception type)
            # and end this game so the caller can move on to the next one.
            print(f'{type(e).__name__}: {e}')
            sio.disconnect()

    try:
        sio.connect(f'http://{ip}:{port}/', wait_timeout=30)
        sio.emit('begin', {'team_id': team_id, 'begin': begin})
        # Returns once the connection ends; the handlers above disconnect on
        # game end or on error.
        sio.wait()
    except socketio.exceptions.ConnectionError:
        print('Connection Error')
        sio.disconnect()
    except Exception as e:
        print(f'Exception: {e}')
        sio.disconnect()


if __name__ == '__main__':
    team_id = 'ewv9ssdcuvg6'
    ip = '69.230.243.237'
    port = '8086'
    # game_type must be one of ['2', 'a']:
    #   '2' -> full game (recognition + actions), 'a' -> actions only.
    game_type = '2'

    # Preliminary round, stage 1: game_data_id must be in ['00000', ..., '00099'].
    # Preliminary round, final leaderboard: game_data_id must be in ['00000', ..., '00199'].
    # NOTE: this run starts at '00002', so ids '00000' and '00001' are skipped.
    game_data_id = [format(idx, '05') for idx in range(2, 100)]

    start_time = time.time()
    for gdi in game_data_id:
        team_play_game(team_id, game_type, gdi, ip, port)
    elapsed = time.time() - start_time
    print(f'Total time: {elapsed:.1f}s')