|
from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional |
|
|
|
import numpy as np |
|
import torch |
|
from ding.utils import BUFFER_REGISTRY |
|
|
|
from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree |
|
from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree |
|
from lzero.mcts.utils import prepare_observation |
|
from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform |
|
from .game_buffer import GameBuffer |
|
|
|
if TYPE_CHECKING: |
|
from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy |
|
|
|
|
|
@BUFFER_REGISTRY.register('game_buffer_muzero') |
|
class MuZeroGameBuffer(GameBuffer): |
|
""" |
|
Overview: |
|
The specific game buffer for MuZero policy. |
|
""" |
|
|
|
def __init__(self, cfg: dict):
    """
    Overview:
        Initialize the MuZero game buffer. The default configuration mechanism is used: if a user passes in a
        cfg with a key that matches an existing key in the default configuration, the user-provided value
        overrides the default one; otherwise the default value is kept.
    Arguments:
        - cfg (:obj:`dict`): user configuration that is merged on top of ``self.default_config()``.
    """
    super().__init__(cfg)

    # Merge the user configuration on top of the defaults (user values win on key collisions).
    merged_cfg = self.default_config()
    merged_cfg.update(cfg)
    self._cfg = merged_cfg
    assert self._cfg.env_type in ['not_board_games', 'board_games'], \
        "env_type must be 'not_board_games' or 'board_games'"
    assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space'], \
        "action_type must be 'fixed_action_space' or 'varied_action_space'"

    # Values copied from the merged configuration.
    self.replay_buffer_size = self._cfg.replay_buffer_size
    self.batch_size = self._cfg.batch_size
    self._alpha = self._cfg.priority_prob_alpha
    self._beta = self._cfg.priority_prob_beta

    # Bookkeeping state, reset to its initial values.
    self.keep_ratio = 1
    self.model_update_interval = 10
    self.num_of_collected_episodes = 0
    self.base_idx = 0
    # ``update_priority`` only writes priorities for samples whose make_time is greater than this value.
    self.clear_time = 0

    # Storage containers, all empty at construction time.
    self.game_segment_buffer = []
    self.game_pos_priorities = []
    self.game_segment_game_pos_look_up = []
|
|
|
def sample(
        self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
    """
    Overview:
        Sample data from ``GameBuffer`` and prepare the current and target batch for training.
    Arguments:
        - batch_size (:obj:`int`): batch size.
        - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy \
            whose ``_target_model`` is used to compute the reanalyzed targets.
    Returns:
        - train_data (:obj:`List`): ``[current_batch, target_batch]`` where \
            ``target_batch = [batch_rewards, batch_target_values, batch_target_policies]``.
    Raises:
        - ValueError: if ``cfg.reanalyze_ratio`` is outside ``[0, 1]``.
    """
    reanalyze_ratio = self._cfg.reanalyze_ratio
    # Fail early and explicitly instead of hitting an unbound local after the expensive target computation.
    if not 0 <= reanalyze_ratio <= 1:
        raise ValueError(f"reanalyze_ratio must be in [0, 1], but got {reanalyze_ratio}")

    target_model = policy._target_model
    target_model.to(self._cfg.device)
    target_model.eval()

    # Obtain the current batch and the contexts needed to compute every target.
    reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
        batch_size, reanalyze_ratio
    )

    # Reward and value targets.
    batch_rewards, batch_target_values = self._compute_target_reward_value(reward_value_context, target_model)

    # Policy targets: a reanalyzed part and a non-reanalyzed part.
    batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, target_model)
    batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
        policy_non_re_context, self._cfg.model.action_space_size
    )

    # Either part is an empty list when its context is None. That also happens for 0 < ratio < 1 when
    # int(batch_size * ratio) rounds down to 0, so decide on actual emptiness: concatenating an empty 1-D
    # list with a 3-D policy array would raise.
    if len(batch_target_policies_re) == 0:
        batch_target_policies = batch_target_policies_non_re
    elif len(batch_target_policies_non_re) == 0:
        batch_target_policies = batch_target_policies_re
    else:
        batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])

    target_batch = [batch_rewards, batch_target_values, batch_target_policies]
    return [current_batch, target_batch]
|
|
|
def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
    """
    Overview:
        Sample orig_data through ``_sample_orig_data()`` and build everything needed to compute targets:
            reward_value_context:   the context of reanalyzed value targets
            policy_re_context:      the context of reanalyzed policy targets (None if nothing is reanalyzed)
            policy_non_re_context:  the context of non-reanalyzed policy targets (None if all is reanalyzed)
            current_batch:          the inputs of the batch, each entry converted with ``np.asarray``
    Arguments:
        - batch_size (:obj:`int`): the batch size of orig_data requested from the replay buffer.
        - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed).
    Returns:
        - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
    """
    segments, positions, batch_index_list, weights_list, make_time_list = self._sample_orig_data(batch_size)
    # From here on the batch size is whatever the sampler actually returned.
    batch_size = len(batch_index_list)

    obs_list, action_list, mask_list = [], [], []
    for sample_id in range(batch_size):
        segment, start = segments[sample_id], positions[sample_id]

        unroll_actions = segment.action_segment[start:start + self._cfg.num_unroll_steps].tolist()
        num_real = len(unroll_actions)

        # Mask: 1 for every real action, 0 for padding, num_unroll_steps + 1 entries in total.
        step_mask = [1. for _ in range(num_real)]
        step_mask += [0. for _ in range(self._cfg.num_unroll_steps + 1 - num_real)]

        # Pad the action sequence up to num_unroll_steps with uniformly random actions.
        unroll_actions += [
            np.random.randint(0, segment.action_space_size)
            for _ in range(self._cfg.num_unroll_steps - num_real)
        ]

        obs_list.append(
            segment.get_unroll_obs(start, num_unroll_steps=self._cfg.num_unroll_steps, padding=True)
        )
        action_list.append(unroll_actions)
        mask_list.append(step_mask)

    obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

    current_batch = [
        np.asarray(entry)
        for entry in (obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list)
    ]

    total_transitions = self.get_num_of_transitions()

    # The value target is always reanalyzed, so its context covers the whole batch.
    reward_value_context = self._prepare_reward_value_context(
        batch_index_list, segments, positions, total_transitions
    )

    # Only the leading ``reanalyze_ratio`` share (e.g. 50%) of the batch gets a reanalyzed policy:
    # [0, reanalyze_num) -> reanalyzed policy, [reanalyze_num, end) -> non-reanalyzed policy.
    # (The original note adds: if ``cfg.reanalyze_outdated`` is True, batch_index_list is sorted
    # according to its generated env_steps.)
    reanalyze_num = int(batch_size * reanalyze_ratio)

    policy_re_context = None
    if reanalyze_num > 0:
        policy_re_context = self._prepare_policy_reanalyzed_context(
            batch_index_list[:reanalyze_num], segments[:reanalyze_num], positions[:reanalyze_num]
        )

    policy_non_re_context = None
    if reanalyze_num < batch_size:
        policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
            batch_index_list[reanalyze_num:], segments[reanalyze_num:], positions[reanalyze_num:]
        )

    return reward_value_context, policy_re_context, policy_non_re_context, current_batch
|
|
|
def _prepare_reward_value_context( |
|
self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], |
|
total_transitions: int |
|
) -> List[Any]: |
|
""" |
|
Overview: |
|
prepare the context of rewards and values for calculating TD value target in reanalyzing part. |
|
Arguments: |
|
- batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer |
|
- game_segment_list (:obj:`list`): list of game segments |
|
- pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment |
|
- total_transitions (:obj:`int`): number of collected transitions |
|
Returns: |
|
- reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, |
|
td_steps_list, action_mask_segment, to_play_segment |
|
""" |
|
zero_obs = game_segment_list[0].zero_obs() |
|
value_obs_list = [] |
|
|
|
value_mask = [] |
|
rewards_list = [] |
|
game_segment_lens = [] |
|
|
|
action_mask_segment, to_play_segment = [], [] |
|
|
|
td_steps_list = [] |
|
for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): |
|
game_segment_len = len(game_segment) |
|
game_segment_lens.append(game_segment_len) |
|
|
|
td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) |
|
|
|
|
|
|
|
|
|
game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) |
|
|
|
rewards_list.append(game_segment.reward_segment) |
|
|
|
|
|
action_mask_segment.append(game_segment.action_mask_segment) |
|
to_play_segment.append(game_segment.to_play_segment) |
|
|
|
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): |
|
|
|
td_steps_list.append(td_steps) |
|
|
|
bootstrap_index = current_index + td_steps |
|
|
|
if bootstrap_index < game_segment_len: |
|
value_mask.append(1) |
|
|
|
beg_index = current_index - state_index |
|
end_index = beg_index + self._cfg.model.frame_stack_num |
|
|
|
obs = game_obs[beg_index:end_index] |
|
else: |
|
value_mask.append(0) |
|
obs = zero_obs |
|
|
|
value_obs_list.append(obs) |
|
|
|
reward_value_context = [ |
|
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, |
|
action_mask_segment, to_play_segment |
|
] |
|
return reward_value_context |
|
|
|
def _prepare_policy_non_reanalyzed_context( |
|
self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int] |
|
) -> List[Any]: |
|
""" |
|
Overview: |
|
prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play |
|
Arguments: |
|
- batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer |
|
- game_segment_list (:obj:`list`): list of game segments |
|
- pos_in_game_segment_list (:obj:`list`): list transition index in game |
|
Returns: |
|
- policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment |
|
""" |
|
child_visits = [] |
|
game_segment_lens = [] |
|
|
|
action_mask_segment, to_play_segment = [], [] |
|
|
|
for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): |
|
game_segment_len = len(game_segment) |
|
game_segment_lens.append(game_segment_len) |
|
|
|
action_mask_segment.append(game_segment.action_mask_segment) |
|
to_play_segment.append(game_segment.to_play_segment) |
|
|
|
child_visits.append(game_segment.child_visit_segment) |
|
|
|
policy_non_re_context = [ |
|
pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment |
|
] |
|
return policy_non_re_context |
|
|
|
def _prepare_policy_reanalyzed_context( |
|
self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] |
|
) -> List[Any]: |
|
""" |
|
Overview: |
|
prepare the context of policies for calculating policy target in reanalyzing part. |
|
Arguments: |
|
- batch_index_list (:obj:'list'): start transition index in the replay buffer |
|
- game_segment_list (:obj:'list'): list of game segments |
|
- pos_in_game_segment_list (:obj:'list'): position of transition index in one game history |
|
Returns: |
|
- policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, |
|
child_visits, game_segment_lens, action_mask_segment, to_play_segment |
|
""" |
|
zero_obs = game_segment_list[0].zero_obs() |
|
with torch.no_grad(): |
|
|
|
policy_obs_list = [] |
|
policy_mask = [] |
|
|
|
|
|
rewards, child_visits, game_segment_lens = [], [], [] |
|
|
|
action_mask_segment, to_play_segment = [], [] |
|
for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): |
|
game_segment_len = len(game_segment) |
|
game_segment_lens.append(game_segment_len) |
|
rewards.append(game_segment.reward_segment) |
|
|
|
action_mask_segment.append(game_segment.action_mask_segment) |
|
to_play_segment.append(game_segment.to_play_segment) |
|
|
|
child_visits.append(game_segment.child_visit_segment) |
|
|
|
game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) |
|
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): |
|
|
|
if current_index < game_segment_len: |
|
policy_mask.append(1) |
|
beg_index = current_index - state_index |
|
end_index = beg_index + self._cfg.model.frame_stack_num |
|
obs = game_obs[beg_index:end_index] |
|
else: |
|
policy_mask.append(0) |
|
obs = zero_obs |
|
policy_obs_list.append(obs) |
|
|
|
policy_re_context = [ |
|
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, |
|
action_mask_segment, to_play_segment |
|
] |
|
return policy_re_context |
|
|
|
def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]:
    """
    Overview:
        Prepare reward and value targets from the context of rewards and values.
    Arguments:
        - reward_value_context (:obj:`list`): the reward value context built by
          ``_prepare_reward_value_context``
        - model (:obj:`Any`): the target model; its ``initial_inference`` is called on the bootstrap
          observations
    Returns:
        - batch_rewards (:obj:`np.ndarray`): object array of per-sample reward targets
        - batch_target_values (:obj:`np.ndarray`): object array of per-sample value targets
    """
    (
        value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
        action_mask_segment, to_play_segment
    ) = reward_value_context
    cfg = self._cfg

    # One entry per (sample, unroll step) pair vs. one entry per sample.
    transition_batch_size = len(value_obs_list)
    game_segment_batch_size = len(pos_in_game_segment_list)

    to_play, action_mask = self._preprocess_to_play_and_action_mask(
        game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
    )
    if cfg.model.continuous_action_space is True:
        # Continuous actions: every action is treated as available and legal actions are marked with -1.
        action_mask = [
            list(np.ones(cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
        ]
        legal_actions = [
            [-1 for _ in range(cfg.model.action_space_size)] for _ in range(transition_batch_size)
        ]
    else:
        legal_actions = [
            [action for action, flag in enumerate(action_mask[row]) if flag == 1]
            for row in range(transition_batch_size)
        ]

    batch_target_values, batch_rewards = [], []
    with torch.no_grad():
        value_obs_list = prepare_observation(value_obs_list, cfg.model.model_type)

        # Run the model over slices of at most ``mini_infer_size`` observations.
        num_slices = int(np.ceil(transition_batch_size / cfg.mini_infer_size))
        network_output = []
        for slice_id in range(num_slices):
            slice_start = cfg.mini_infer_size * slice_id
            slice_stop = cfg.mini_infer_size * (slice_id + 1)
            slice_obs = torch.from_numpy(value_obs_list[slice_start:slice_stop]).to(cfg.device).float()

            slice_output = model.initial_inference(slice_obs)
            if not model.training:
                # Outside training mode, convert the value back to a scalar and move everything to numpy.
                slice_output.latent_state, slice_output.value, slice_output.policy_logits = to_detach_cpu_numpy(
                    [
                        slice_output.latent_state,
                        inverse_scalar_transform(slice_output.value, cfg.model.support_scale),
                        slice_output.policy_logits
                    ]
                )
            network_output.append(slice_output)

        if cfg.use_root_value:
            # Bootstrap from MCTS root values instead of the raw network value.
            _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(
                network_output, data_type='muzero'
            )
            reward_pool = reward_pool.squeeze().tolist()
            policy_logits_pool = policy_logits_pool.tolist()
            # One Dirichlet sample per root, sized by the number of available actions of that root.
            noises = [
                np.random.dirichlet([cfg.root_dirichlet_alpha] * int(sum(action_mask[row]))
                                    ).astype(np.float32).tolist() for row in range(transition_batch_size)
            ]
            # Both tree implementations are driven through the same calls.
            mcts_cls = MCTSCtree if cfg.mcts_ctree else MCTSPtree
            roots = mcts_cls.roots(transition_batch_size, legal_actions)
            roots.prepare(cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
            mcts_cls(cfg).search(roots, model, latent_state_roots, to_play)

            value_list = np.array(roots.get_values())
        else:
            # Bootstrap from the predicted values.
            value_list = concat_output_value(network_output)

        # Discount the bootstrap value by td_steps. In the board-game case the sign is flipped for an odd
        # number of steps.
        board_game_with_players = cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]
        if board_game_with_players:
            value_list = value_list.reshape(-1) * np.array(
                [
                    cfg.discount_factor ** td_steps_list[row]
                    if int(td_steps_list[row]) % 2 == 0 else -cfg.discount_factor ** td_steps_list[row]
                    for row in range(transition_batch_size)
                ]
            )
        else:
            value_list = value_list.reshape(-1) * (
                np.array([cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
            )

        # Zero out bootstrap values whose position fell outside the segment.
        value_list = (value_list * np.array(value_mask)).tolist()

        value_index = 0
        for segment_len, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list,
                                                                       pos_in_game_segment_list,
                                                                       to_play_segment):
            target_values, target_rewards = [], []
            base_index = state_index
            for current_index in range(state_index, state_index + cfg.num_unroll_steps + 1):
                bootstrap_index = current_index + td_steps_list[value_index]

                # Add the discounted rewards collected between current_index and the bootstrap position.
                for offset, reward in enumerate(reward_list[current_index:bootstrap_index]):
                    # NOTE(review): ``offset`` is relative to current_index, yet it indexes to_play_list
                    # directly and is compared against the player at the sample start, not at
                    # current_index. This looks like it should be ``to_play_list[current_index + offset]``
                    # vs ``to_play_list[current_index]``; behavior is kept as-is, please confirm.
                    if board_game_with_players and to_play_list[base_index] != to_play_list[offset]:
                        value_list[value_index] += -reward * cfg.discount_factor ** offset
                    else:
                        value_list[value_index] += reward * cfg.discount_factor ** offset

                if current_index < segment_len:
                    target_values.append(value_list[value_index])
                    target_rewards.append(reward_list[current_index])
                else:
                    # Positions past the end of the segment get zero targets.
                    target_values.append(0)
                    target_rewards.append(0.0)

                value_index += 1

            batch_rewards.append(target_rewards)
            batch_target_values.append(target_values)

    batch_rewards = np.asarray(batch_rewards, dtype=object)
    batch_target_values = np.asarray(batch_target_values, dtype=object)
    return batch_rewards, batch_target_values
|
|
|
def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: |
|
""" |
|
Overview: |
|
prepare policy targets from the reanalyzed context of policies |
|
Arguments: |
|
- policy_re_context (:obj:`List`): List of policy context to reanalyzed |
|
Returns: |
|
- batch_target_policies_re |
|
""" |
|
if policy_re_context is None: |
|
return [] |
|
batch_target_policies_re = [] |
|
|
|
|
|
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ |
|
to_play_segment = policy_re_context |
|
|
|
transition_batch_size = len(policy_obs_list) |
|
game_segment_batch_size = len(pos_in_game_segment_list) |
|
|
|
to_play, action_mask = self._preprocess_to_play_and_action_mask( |
|
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list |
|
) |
|
|
|
if self._cfg.model.continuous_action_space is True: |
|
|
|
action_mask = [ |
|
list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) |
|
] |
|
|
|
legal_actions = [ |
|
[-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) |
|
] |
|
else: |
|
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] |
|
|
|
with torch.no_grad(): |
|
policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) |
|
|
|
slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) |
|
network_output = [] |
|
for i in range(slices): |
|
beg_index = self._cfg.mini_infer_size * i |
|
end_index = self._cfg.mini_infer_size * (i + 1) |
|
m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() |
|
m_output = model.initial_inference(m_obs) |
|
if not model.training: |
|
|
|
[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( |
|
[ |
|
m_output.latent_state, |
|
inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), |
|
m_output.policy_logits |
|
] |
|
) |
|
|
|
network_output.append(m_output) |
|
|
|
_, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') |
|
reward_pool = reward_pool.squeeze().tolist() |
|
policy_logits_pool = policy_logits_pool.tolist() |
|
noises = [ |
|
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size |
|
).astype(np.float32).tolist() for _ in range(transition_batch_size) |
|
] |
|
if self._cfg.mcts_ctree: |
|
|
|
roots = MCTSCtree.roots(transition_batch_size, legal_actions) |
|
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) |
|
|
|
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) |
|
else: |
|
|
|
roots = MCTSPtree.roots(transition_batch_size, legal_actions) |
|
roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) |
|
|
|
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) |
|
|
|
roots_legal_actions_list = legal_actions |
|
roots_distributions = roots.get_distributions() |
|
policy_index = 0 |
|
for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list): |
|
target_policies = [] |
|
|
|
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): |
|
distributions = roots_distributions[policy_index] |
|
|
|
if policy_mask[policy_index] == 0: |
|
|
|
target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) |
|
else: |
|
if distributions is None: |
|
|
|
target_policies.append( |
|
list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) |
|
) |
|
else: |
|
if self._cfg.action_type == 'fixed_action_space': |
|
|
|
sum_visits = sum(distributions) |
|
policy = [visit_count / sum_visits for visit_count in distributions] |
|
target_policies.append(policy) |
|
else: |
|
|
|
policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] |
|
|
|
sum_visits = sum(distributions) |
|
policy = [visit_count / sum_visits for visit_count in distributions] |
|
for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): |
|
policy_tmp[legal_action] = policy[index] |
|
target_policies.append(policy_tmp) |
|
|
|
policy_index += 1 |
|
|
|
batch_target_policies_re.append(target_policies) |
|
|
|
batch_target_policies_re = np.array(batch_target_policies_re) |
|
|
|
return batch_target_policies_re |
|
|
|
def _compute_target_policy_non_reanalyzed( |
|
self, policy_non_re_context: List[Any], policy_shape: Optional[int] |
|
) -> np.ndarray: |
|
""" |
|
Overview: |
|
prepare policy targets from the non-reanalyzed context of policies |
|
Arguments: |
|
- policy_non_re_context (:obj:`List`): List containing: |
|
- pos_in_game_segment_list |
|
- child_visits |
|
- game_segment_lens |
|
- action_mask_segment |
|
- to_play_segment |
|
- policy_shape: self._cfg.model.action_space_size |
|
Returns: |
|
- batch_target_policies_non_re |
|
""" |
|
batch_target_policies_non_re = [] |
|
if policy_non_re_context is None: |
|
return batch_target_policies_non_re |
|
|
|
pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context |
|
game_segment_batch_size = len(pos_in_game_segment_list) |
|
transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) |
|
|
|
to_play, action_mask = self._preprocess_to_play_and_action_mask( |
|
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list |
|
) |
|
|
|
if self._cfg.model.continuous_action_space is True: |
|
|
|
action_mask = [ |
|
list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) |
|
] |
|
|
|
legal_actions = [ |
|
[-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) |
|
] |
|
else: |
|
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] |
|
|
|
with torch.no_grad(): |
|
policy_index = 0 |
|
|
|
|
|
policy_mask = [] |
|
for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits, |
|
pos_in_game_segment_list): |
|
target_policies = [] |
|
|
|
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): |
|
if current_index < game_segment_len: |
|
policy_mask.append(1) |
|
|
|
distributions = child_visit[current_index] |
|
if self._cfg.action_type == 'fixed_action_space': |
|
|
|
target_policies.append(distributions) |
|
else: |
|
|
|
policy_tmp = [0 for _ in range(policy_shape)] |
|
for index, legal_action in enumerate(legal_actions[policy_index]): |
|
|
|
policy_tmp[legal_action] = distributions[index] |
|
target_policies.append(policy_tmp) |
|
else: |
|
|
|
policy_mask.append(0) |
|
target_policies.append([0 for _ in range(policy_shape)]) |
|
|
|
policy_index += 1 |
|
|
|
batch_target_policies_non_re.append(target_policies) |
|
batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) |
|
return batch_target_policies_non_re |
|
|
|
def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
    """
    Overview:
        Update the priority of training data.
    Arguments:
        - train_data (:obj:`List[np.ndarray]`): training data whose priorities are to be updated.
        - batch_priorities (:obj:`batch_priorities`): the new priorities, one per sampled transition.
    NOTE:
        train_data = [current_batch, target_batch]
        current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list,
                         batch_index_list, weights, make_time_list]
    """
    current_batch = train_data[0]
    # Entries are addressed from the end of current_batch, so the optional improved_policy_list entry
    # listed in the NOTE above does not shift them.
    batch_indices, make_times = current_batch[-3], current_batch[-1]

    for row, buffer_index in enumerate(batch_indices):
        # Only samples made after ``clear_time`` are written back; older ones are skipped
        # (presumably because their buffer slots may no longer hold the same transition).
        if make_times[row] > self.clear_time:
            self.game_pos_priorities[buffer_index] = batch_priorities[row]
|
|