import sys
# Make a local LightZero checkout importable; adjust this path to your setup.
sys.path.append("/Users/puyuan/code/LightZero/")
from functools import partial
import torch
from ding.config import compile_config
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import main_config, create_config


class Agent:
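    """
    Thin inference wrapper around a LightZero AlphaZero policy for Gomoku
    (bot-mode config): it builds the environment/policy stack once in
    __init__ and exposes a single compute_action() method.
    """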
    def __init__(self, seed=0):
        # Path to a pretrained checkpoint; leave as None to use a freshly initialized model.
        # model_path = './ckpt/ckpt_best.pth.tar'
        model_path = None

        # If True, a human can play against the agent.
        # main_config.env.agent_vs_human = True
        main_config.env.agent_vs_human = False
        # Save rendered frames to disk (under replay_path) instead of rendering in real time.
        # main_config.env.render_mode = 'image_realtime_mode'
        main_config.env.render_mode = 'image_savefile_mode'
        main_config.env.replay_path = './video'

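        # Run everything in a single process with the pure-Python MCTS
        # (ctree=False disables the compiled C++ tree search), using one
        # evaluator environment and one evaluation episode.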
        create_config.env_manager.type = 'base'
        main_config.env.alphazero_mcts_ctree = False
        main_config.policy.mcts_ctree = False
        main_config.env.evaluator_env_num = 1
        main_config.env.n_evaluator_episode = 1

        cfg, create_cfg = main_config, create_config

        if cfg.policy.cuda and torch.cuda.is_available():
            cfg.policy.device = 'cuda'
        else:
            cfg.policy.device = 'cpu'

        cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
        # Create main components: env, policy
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
        collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
        evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
        collector_env.seed(cfg.seed)
        evaluator_env.seed(cfg.seed, dynamic_seed=False)
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
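        # Create the policy with all three modes enabled: 'eval' drives inference
        # below, while 'learn' is needed to load a checkpoint via load_state_dict.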
        self.policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval'])

        # Load the pretrained checkpoint, if one was specified above.
        if model_path is not None:
            self.policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

    def compute_action(self, obs):
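        """
        Return the agent's action for a single observation. The eval-mode
        policy expects a dict keyed by env id (here '0') and returns one
        output dict per env, from which the action is extracted.
        """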
        policy_output = self.policy.eval_mode.forward({'0': obs})
        actions = {env_id: output['action'] for env_id, output in policy_output.items()}
        return actions['0']


if __name__ == '__main__':
    from easydict import EasyDict
    from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
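    # Minimal standalone Gomoku environment for a quick demo: a random player
    # moves first and the AlphaZero agent replies, both via env.step().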
    cfg = EasyDict(
        prob_random_agent=0,
        board_size=15,
        battle_mode='self_play_mode',  # NOTE: both players' moves go through env.step() in the loop below.
        channel_last=False,
        scale=False,
        agent_vs_human=False,
        bot_action_type='v1',  # {'v0', 'v1', 'alpha_beta_pruning'}
        prob_random_action_in_bot=0.,
        check_action_to_connect4_in_bot_v0=False,
        render_mode='state_realtime_mode',
        replay_path=None,
        screen_scaling=9,
        alphazero_mcts_ctree=False,
    )
    env = GomokuEnv(cfg)
    obs = env.reset()
    agent = Agent()

    while True:
        # Advance the game with a random move for the first player.
        observation, reward, done, info = env.step(env.random_action())
        if done:
            break
        # The game is still running: ask the agent for its reply.
        # agent_action = env.random_action()
        agent_action = agent.compute_action(observation)
        # Apply the agent's move to the environment.
        _, _, done, _ = env.step(agent_action)
        # Print the move as a flat index and as (i, j) coordinates on the 15x15 board.
        print('orig bot action: {}'.format(agent_action))
        agent_action = {'i': int(agent_action // 15), 'j': int(agent_action % 15)}
        print('bot action: {}'.format(agent_action))
        # Stop once the agent's move has ended the game.
        if done:
            break