from typing import Any, Tuple, List

import numpy as np
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.utils import prepare_observation
from .game_buffer_muzero import MuZeroGameBuffer


@BUFFER_REGISTRY.register('game_buffer_stochastic_muzero')
class StochasticMuZeroGameBuffer(MuZeroGameBuffer):
    """
    Overview:
        The specific game buffer for the Stochastic MuZero policy.
    """

    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Use the default configuration mechanism. If a key in the user-provided ``cfg`` matches a key in the
            default configuration, the user-provided value overrides the default; otherwise, the default value
            is used.
        """
        super().__init__(cfg)
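        # Illustrative example (hypothetical values): if default_config contains
        # {'batch_size': 256, 'replay_buffer_size': 1e6} and the user passes cfg = {'batch_size': 64},
        # the merged self._cfg ends up with batch_size=64 and the default replay_buffer_size.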
        default_config = self.default_config()
        default_config.update(cfg)
        self._cfg = default_config
        assert self._cfg.env_type in ['not_board_games', 'board_games']
        assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
        self.replay_buffer_size = self._cfg.replay_buffer_size
        self.batch_size = self._cfg.batch_size
        self._alpha = self._cfg.priority_prob_alpha
        self._beta = self._cfg.priority_prob_beta

        self.keep_ratio = 1
        self.model_update_interval = 10
        self.num_of_collected_episodes = 0
        self.base_idx = 0
        self.clear_time = 0

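        # storage for the collected game segments, the priorities of each stored position,
        # and the game-segment/position look-up used when sampling transitions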
        self.game_segment_buffer = []
        self.game_pos_priorities = []
        self.game_segment_game_pos_look_up = []

    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
        """
        Overview:
            First sample orig_data through ``_sample_orig_data()``, then prepare the context of a batch:
                reward_value_context:    the context of reanalyzed value targets
                policy_re_context:       the context of reanalyzed policy targets
                policy_non_re_context:   the context of non-reanalyzed policy targets
                current_batch:           the inputs of the batch
        Arguments:
            - batch_size (:obj:`int`): the batch size of orig_data sampled from the replay buffer.
            - reanalyze_ratio (:obj:`float`): the ratio of reanalyzed policy targets (value targets are always reanalyzed).
        Returns:
            - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
        """
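        # sample the original data: game segments, positions within each segment, buffer indices,
        # sampling weights and make-times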
        orig_data = self._sample_orig_data(batch_size)
        game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
        batch_size = len(batch_index_list)
        obs_list, action_list, mask_list = [], [], []
        if self._cfg.use_ture_chance_label_in_chance_encoder:
            chance_list = []

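        # prepare the inputs (observations, actions, masks and, optionally, chance labels) for every sampled position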
        for i in range(batch_size):
            game = game_segment_list[i]
            pos_in_game_segment = pos_in_game_segment_list[i]

            actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
                                              self._cfg.num_unroll_steps].tolist()
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment +
                                                  self._cfg.num_unroll_steps].tolist()

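            # mask: 1 for actions that actually exist in the segment, 0 for the padded steps beyond its end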
            mask_tmp = [1. for _ in range(len(actions_tmp))]
            mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))]

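            # pad actions (and chance labels) with random values up to num_unroll_steps;
            # the padded steps are masked out by mask_tmp above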
            actions_tmp += [
                np.random.randint(0, game.action_space_size)
                for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
            ]
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                chances_tmp += [
                    np.random.randint(0, game.action_space_size)
                    for _ in range(self._cfg.num_unroll_steps - len(chances_tmp))
                ]

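            # obtain the observation sequence for the unrolled steps starting at this position (with padding)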
            obs_list.append(
                game_segment_list[i].get_unroll_obs(
                    pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
                )
            )
            action_list.append(actions_tmp)
            mask_list.append(mask_tmp)
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                chance_list.append(chances_tmp)

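        # convert the raw observations into the format expected by the model type (e.g. image or vector inputs)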
        obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

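        # formalize the inputs of a batch; chance labels are included only when
        # use_ture_chance_label_in_chance_encoder is True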
        if self._cfg.use_ture_chance_label_in_chance_encoder:
            current_batch = [
                obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list, chance_list
            ]
        else:
            current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
        for i in range(len(current_batch)):
            current_batch[i] = np.asarray(current_batch[i])

        total_transitions = self.get_num_of_transitions()

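        # obtain the context needed to compute the value targets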
        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
        )
        """
        Only reanalyze the most recent `reanalyze_ratio` (e.g. 50%) of the data.
        If self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to the env_steps at which
        the data was generated: indices [0, reanalyze_num) get reanalyzed policy targets, while
        [reanalyze_num, batch_size) keep non-reanalyzed policy targets.
        """
        reanalyze_num = int(batch_size * reanalyze_ratio)

        if reanalyze_num > 0:
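            # obtain the context of reanalyzed policy targets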
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
                pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        if reanalyze_num < batch_size:
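            # obtain the context of non-reanalyzed policy targets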
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
                pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None

        context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
        return context

    def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
        """
        Overview:
            Update the priorities of the sampled training data.
        Arguments:
            - train_data (:obj:`List[np.ndarray]`): the training data whose priorities are to be updated.
            - batch_priorities (:obj:`Any`): the new priorities corresponding to the samples in ``train_data``.
        NOTE:
            train_data = [current_batch, target_batch]
            if self._cfg.use_ture_chance_label_in_chance_encoder:
                obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch
            else:
                obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch
        """
        indices = train_data[0][3]
        metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities}

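        # only update the priority of samples whose make_time is later than the last buffer clear time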
        for i in range(len(indices)):
            if metas['make_time'][i] > self.clear_time:
                idx, prio = indices[i], metas['batch_priorities'][i]
                self.game_pos_priorities[idx] = prio