|
from typing import Any, Tuple

import numpy as np
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.buffer import MuZeroGameBuffer
from lzero.mcts.utils import prepare_observation


@BUFFER_REGISTRY.register('game_buffer_gumbel_muzero')
class GumbelMuZeroGameBuffer(MuZeroGameBuffer):
    """
    Overview:
        The specific game buffer for the Gumbel MuZero policy.
    """

    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
        """
        Overview:
            First sample orig_data through ``_sample_orig_data()``, then prepare the context of a batch:
                - reward_value_context: the context of reanalyzed value targets.
                - policy_re_context: the context of reanalyzed policy targets.
                - policy_non_re_context: the context of non-reanalyzed policy targets.
                - current_batch: the inputs of the batch.
        Arguments:
            - batch_size (:obj:`int`): the batch size of orig_data sampled from the replay buffer.
            - reanalyze_ratio (:obj:`float`): the ratio of reanalyzed policy targets
                (the value targets are always reanalyzed).
        Returns:
            - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch.
        """
        # Obtain the batch context (game segments, positions, batch indices, weights, make times) from the buffer.
        orig_data = self._sample_orig_data(batch_size)
        game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
        batch_size = len(batch_index_list)

        obs_list, action_list, improved_policy_list, mask_list = [], [], [], []

        # Prepare the inputs of a batch: for each sampled position, collect the unrolled actions,
        # the stored improved policies, the loss masks and the stacked observations.
        for i in range(batch_size):
            game = game_segment_list[i]
            pos_in_game_segment = pos_in_game_segment_list[i]

            # Actions executed over the unroll horizon.
            actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
                                              self._cfg.num_unroll_steps].tolist()

            # Improved policy probabilities recorded in the game segment over the unroll horizon.
            _improved_policy = game.improved_policy_probs[
                pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps]
            if not isinstance(_improved_policy, list):
                _improved_policy = _improved_policy.tolist()

            # Mask: 1 for valid unroll steps, 0 for steps padded beyond the end of the trajectory.
            mask_tmp = [1. for i in range(len(actions_tmp))]
            mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]

            # Pad the action sequence with random actions up to num_unroll_steps entries.
            actions_tmp += [
                np.random.randint(0, game.action_space_size)
                for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
            ]

            # Pad the improved policy with random Dirichlet samples so that it has num_unroll_steps + 1 entries.
            _improved_policy.extend(np.random.dirichlet(np.ones(game.action_space_size),
                                                        size=self._cfg.num_unroll_steps + 1 - len(_improved_policy)))

            # Obtain the input observations over the unroll steps; get_unroll_obs pads when the remaining
            # segment is shorter than the unroll horizon.
            obs_list.append(
                game_segment_list[i].get_unroll_obs(
                    pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
                )
            )
            action_list.append(actions_tmp)
            improved_policy_list.append(_improved_policy)
            mask_list.append(mask_tmp)

        # Format the stacked observations according to the model type.
        obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

        # Collect the batch inputs and convert each field to a numpy array.
        current_batch = [
            obs_list, action_list, improved_policy_list, mask_list, batch_index_list, weights_list, make_time_list
        ]
        for i in range(len(current_batch)):
            current_batch[i] = np.asarray(current_batch[i])

        total_transitions = self.get_num_of_transitions()

        # Obtain the context of the value targets (values are always reanalyzed).
        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
        )
        """
        Only reanalyze the most recent ``reanalyze_ratio`` (e.g. 50%) of the sampled data.
        If ``self._cfg.reanalyze_outdated`` is True, ``batch_index_list`` is sorted by the env steps at which the
        data was generated: indices [0, reanalyze_num) get reanalyzed policy targets, while indices
        [reanalyze_num, batch_size) keep the non-reanalyzed policy targets.
        """
        reanalyze_num = int(batch_size * reanalyze_ratio)
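        # For example, with batch_size = 256 and reanalyze_ratio = 0.5 (illustrative numbers only),
        # reanalyze_num = 128, so the first 128 sampled positions receive freshly reanalyzed policy targets
        # below, while the remaining 128 reuse the policy targets stored in their game segments.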
|
|
|
        # Reanalyzed policy targets.
        if reanalyze_num > 0:
            # Obtain the context of the reanalyzed policy targets.
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
                pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        # Non-reanalyzed policy targets.
        if reanalyze_num < batch_size:
            # Obtain the context of the non-reanalyzed policy targets.
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
                pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None

        context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
        return context
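
# Minimal usage sketch (illustrative only, kept as a comment so it is not executed on import). It shows how the
# four contexts returned by ``_make_batch`` are typically unpacked; the ``cfg`` object and the concrete
# batch_size/reanalyze_ratio values are assumptions of this example, not requirements of the class.
#
#   buffer = GumbelMuZeroGameBuffer(cfg)
#   reward_value_context, policy_re_context, policy_non_re_context, current_batch = \
#       buffer._make_batch(batch_size=256, reanalyze_ratio=0.5)
#   obs, actions, improved_policies, masks, batch_indices, weights, make_times = current_batch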
|
|