import numpy as np
import pytest
import torch

from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation
from lzero.policy import select_action

args = ["MuZero"]


@pytest.mark.unittest
@pytest.mark.parametrize('test_algo', args)
def test_game_segment(test_algo):
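    # Import the algorithm-specific MCTS tree search, model, test config and environment for ``test_algo``.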
    if test_algo == 'EfficientZero':
        from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
        from lzero.model.efficientzero_model import EfficientZeroModel as Model
        from lzero.mcts.tests.config.atari_efficientzero_config_for_test import atari_efficientzero_config as config
        from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv
        envs = [AtariLightZeroEnv(config.env) for _ in range(config.env.evaluator_env_num)]

    elif test_algo == 'MuZero':
        from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
        from lzero.model.muzero_model import MuZeroModel as Model
        from lzero.mcts.tests.config.tictactoe_muzero_bot_mode_config_for_test import tictactoe_muzero_config as config
        from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
        envs = [TicTacToeEnv(config.env) for _ in range(config.env.evaluator_env_num)]

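    # Build the model from the test config and move it to the selected device.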
    model = Model(**config.policy.model)
    if config.policy.cuda and torch.cuda.is_available():
        config.policy.device = 'cuda'
    else:
        config.policy.device = 'cpu'
    model.to(config.policy.device)
    model.eval()

    with torch.no_grad():
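        # Reset each environment and create one GameSegment per environment, seeded with the stacked initial frames.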
        init_observations = [env.reset() for env in envs]
        dones = np.array([False for _ in range(config.env.evaluator_env_num)])
        game_segments = [
            GameSegment(
                envs[i].action_space, game_segment_length=config.policy.game_segment_length, config=config.policy
            ) for i in range(config.env.evaluator_env_num)
        ]
        for i in range(config.env.evaluator_env_num):
            game_segments[i].reset(
                [init_observations[i]['observation'] for _ in range(config.policy.model.frame_stack_num)]
            )
        episode_rewards = np.zeros(config.env.evaluator_env_num)

        while not dones.all():
            # Stack the most recent frames of every environment and convert them to the model input format.
            stack_obs = [game_segment.get_obs() for game_segment in game_segments]
            stack_obs = prepare_observation(stack_obs, config.policy.model.model_type)
            stack_obs = torch.from_numpy(np.array(stack_obs)).to(config.policy.device)

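            # Initial inference: encode the stacked observations into latent states, policy logits and value outputs for the MCTS roots.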
            network_output = model.initial_inference(stack_obs)

            policy_logits_pool = network_output.policy_logits.detach().cpu().numpy().tolist()
            latent_state_roots = network_output.latent_state.detach().cpu().numpy()

            if test_algo == 'EfficientZero':
                reward_hidden_state_roots = network_output.reward_hidden_state
                value_prefix_pool = network_output.value_prefix
                reward_hidden_state_roots = (
                    reward_hidden_state_roots[0].detach().cpu().numpy(),
                    reward_hidden_state_roots[1].detach().cpu().numpy()
                )
                # In the Atari environment every action is legal.
                legal_actions_list = [
                    [i for i in range(config.policy.model.action_space_size)]
                    for _ in range(config.env.evaluator_env_num)
                ]
            elif test_algo == 'MuZero':
                reward_pool = network_output.reward
                # For the board game, only the actions allowed by the action mask are legal.
                legal_actions_list = [
                    [a for a, x in enumerate(init_observations[i]['action_mask']) if x == 1]
                    for i in range(config.env.evaluator_env_num)
                ]

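            # ``to_play = -1`` corresponds to the play-with-bot (single-player) setting, so no player index is tracked.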
            to_play = [-1 for _ in range(config.env.evaluator_env_num)]

            # Prepare the roots without exploration noise and run MCTS from the root latent states.
            if test_algo == 'EfficientZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
            elif test_algo == 'MuZero':
                roots = MCTSCtree.roots(config.env.evaluator_env_num, legal_actions_list)
                roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play)
                MCTSCtree(config.policy).search(roots, model, latent_state_roots, to_play)

            roots_distributions = roots.get_distributions()
            roots_values = roots.get_values()

            for i in range(config.env.evaluator_env_num):
                distributions, value, env = roots_distributions[i], roots_values[i], envs[i]

                # ``deterministic=True`` selects the most visited root action instead of sampling from the visit counts.
                action, _ = select_action(distributions, temperature=1, deterministic=True)

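                # Step the environment with the selected action and keep only the raw observation from the returned dict.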
                obs, reward, done, info = env.step(action)
                obs = obs['observation']

                # Store the MCTS search statistics and the transition (action, obs, reward) into the game segment.
                game_segments[i].store_search_stats(distributions, value)
                game_segments[i].append(action, obs, reward)

                dones[i] = done
                episode_rewards[i] += reward
                if dones[i]:
                    continue

    for env in envs:
        env.close()