| | from typing import Any, List, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from ding.utils import BUFFER_REGISTRY |
| |
|
| | from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree |
| | from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree |
| | from lzero.mcts.utils import prepare_observation, generate_random_actions_discrete |
| | from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform |
| | from .game_buffer_efficientzero import EfficientZeroGameBuffer |
| |
|
| |
|
| | @BUFFER_REGISTRY.register('game_buffer_sampled_efficientzero') |
| | class SampledEfficientZeroGameBuffer(EfficientZeroGameBuffer): |
| | """ |
| | Overview: |
| | The specific game buffer for Sampled EfficientZero policy. |
| | """ |
| |
|
| | def __init__(self, cfg: dict): |
| | super().__init__(cfg) |
| | """ |
| | Overview: |
| | Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key |
| | in the default configuration, the user-provided value will override the default configuration. Otherwise, |
| | the default configuration will be used. |
| | """ |
| | default_config = self.default_config() |
| | default_config.update(cfg) |
| | self._cfg = default_config |
| | assert self._cfg.env_type in ['not_board_games', 'board_games'] |
| | assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space'] |
| | self.replay_buffer_size = self._cfg.replay_buffer_size |
| | self.batch_size = self._cfg.batch_size |
| | self._alpha = self._cfg.priority_prob_alpha |
| | self._beta = self._cfg.priority_prob_beta |
| |
|
| | self.game_segment_buffer = [] |
| | self.game_pos_priorities = [] |
| | self.game_segment_game_pos_look_up = [] |
| |
|
| | self.keep_ratio = 1 |
| | self.num_of_collected_episodes = 0 |
| | self.base_idx = 0 |
| | self.clear_time = 0 |
| |
|
| | def sample(self, batch_size: int, policy: Any) -> List[Any]: |
| | """ |
| | Overview: |
| | sample data from ``GameBuffer`` and prepare the current and target batch for training |
| | Arguments: |
| | - batch_size (:obj:`int`): batch size |
| | - policy (:obj:`torch.tensor`): model of policy |
| | Returns: |
| | - train_data (:obj:`List`): List of train data |
| | """ |
| |
|
| | policy._target_model.to(self._cfg.device) |
| | policy._target_model.eval() |
| |
|
| | reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( |
| | batch_size, self._cfg.reanalyze_ratio |
| | ) |
| |
|
| | |
| | batch_value_prefixs, batch_target_values = self._compute_target_reward_value( |
| | reward_value_context, policy._target_model |
| | ) |
| |
|
| | batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( |
| | policy_non_re_context, self._cfg.model.num_of_sampled_actions |
| | ) |
| |
|
| | if self._cfg.reanalyze_ratio > 0: |
| | |
| | batch_target_policies_re, root_sampled_actions = self._compute_target_policy_reanalyzed( |
| | policy_re_context, policy._target_model |
| | ) |
| | |
| | |
| | |
| | |
| | |
| |
|
| | assert (self._cfg.reanalyze_ratio > 0 and self._cfg.reanalyze_outdated is True), \ |
| | "in sampled effiicientzero, if self._cfg.reanalyze_ratio>0, you must set self._cfg.reanalyze_outdated=True" |
| | |
| | if self._cfg.model.continuous_action_space: |
| | current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape( |
| | int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1, |
| | self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size |
| | ) |
| | else: |
| | current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape( |
| | int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1, |
| | self._cfg.model.num_of_sampled_actions, 1 |
| | ) |
| |
|
| | if 0 < self._cfg.reanalyze_ratio < 1: |
| | try: |
| | batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) |
| | except Exception as error: |
| | print(error) |
| | elif self._cfg.reanalyze_ratio == 1: |
| | batch_target_policies = batch_target_policies_re |
| | elif self._cfg.reanalyze_ratio == 0: |
| | batch_target_policies = batch_target_policies_non_re |
| |
|
| | target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies] |
| | |
| | train_data = [current_batch, target_batch] |
| | return train_data |
| |
|
    def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
        """
        Overview:
            first sample orig_data through ``_sample_orig_data()``,
            then prepare the context of a batch:
                reward_value_context:   the context of reanalyzed value targets
                policy_re_context:      the context of reanalyzed policy targets
                policy_non_re_context:  the context of non-reanalyzed policy targets
                current_batch:          the inputs of batch
        Arguments:
            - batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
            - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
        Returns:
            - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
        """
        # obtain the batch_size transitions to train on
        orig_data = self._sample_orig_data(batch_size)
        game_lst, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
        batch_size = len(batch_index_list)
        obs_list, action_list, mask_list = [], [], []
        root_sampled_actions_list = []
        # prepare the inputs of a batch
        for i in range(batch_size):
            game = game_lst[i]
            pos_in_game_segment = pos_in_game_segment_list[i]

            # executed actions from the sampled position, truncated to num_unroll_steps
            actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
                                              self._cfg.num_unroll_steps].tolist()

            # stored root sampled-action candidates for each of the num_unroll_steps + 1 positions
            root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment +
                                                                 self._cfg.num_unroll_steps + 1]

            # mask: 1 for real unrolled steps, 0 for steps padded past the segment end
            mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))]
            mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]

            # pad random actions at the end of the game segment so every sample has a full unroll
            if self._cfg.model.continuous_action_space:
                actions_tmp += [
                    np.random.randn(self._cfg.model.action_space_size)
                    for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
                ]
                root_sampled_actions_tmp += [
                    np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size)
                    for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
                ]
            else:
                # discrete case: generate random padding actions (one action per padded step)
                actions_tmp += generate_random_actions_discrete(
                    self._cfg.num_unroll_steps - len(actions_tmp),
                    self._cfg.model.action_space_size,
                    1
                )

                # NOTE(review): the ctree backend appears to expect a different candidate layout
                # than the ptree backend, hence the conditional reshape — confirm against
                # ``generate_random_actions_discrete`` and the tree implementations.
                reshape = True if self._cfg.mcts_ctree else False
                root_sampled_actions_tmp += generate_random_actions_discrete(
                    self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp),
                    self._cfg.model.action_space_size,
                    self._cfg.model.num_of_sampled_actions,
                    reshape=reshape
                )

            # obtain the stacked observations of the unrolled steps (padded at segment end)
            obs_list.append(
                game_lst[i].get_unroll_obs(
                    pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
                )
            )
            action_list.append(actions_tmp)
            root_sampled_actions_list.append(root_sampled_actions_tmp)

            mask_list.append(mask_tmp)

        # formalize the input observations
        obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

        # formalize the inputs of a batch
        current_batch = [
            obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list
        ]

        for i in range(len(current_batch)):
            current_batch[i] = np.asarray(current_batch[i])

        total_transitions = self.get_num_of_transitions()

        # obtain the context of value targets
        reward_value_context = self._prepare_reward_value_context(
            batch_index_list, game_lst, pos_in_game_segment_list, total_transitions
        )
        """
        only reanalyze recent reanalyze_ratio (e.g. 50%) data
        if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
        0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
        """
        reanalyze_num = int(batch_size * reanalyze_ratio)
        # reanalyzed policy: the head of the (env_steps-sorted) batch
        if reanalyze_num > 0:
            # obtain the context of reanalyzed policy targets
            policy_re_context = self._prepare_policy_reanalyzed_context(
                batch_index_list[:reanalyze_num], game_lst[:reanalyze_num], pos_in_game_segment_list[:reanalyze_num]
            )
        else:
            policy_re_context = None

        # non-reanalyzed policy: the remaining tail of the batch
        if reanalyze_num < batch_size:
            # obtain the context of non-reanalyzed policy targets
            policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
                batch_index_list[reanalyze_num:], game_lst[reanalyze_num:], pos_in_game_segment_list[reanalyze_num:]
            )
        else:
            policy_non_re_context = None

        context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
        return context
| |
|
| | def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]: |
| | """ |
| | Overview: |
| | prepare reward and value targets from the context of rewards and values. |
| | Arguments: |
| | - reward_value_context (:obj:'list'): the reward value context |
| | - model (:obj:'torch.tensor'):model of the target model |
| | Returns: |
| | - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix |
| | - batch_target_values (:obj:'np.ndarray): batch of value estimation |
| | """ |
| | value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ |
| | to_play_segment = reward_value_context |
| |
|
| | |
| | transition_batch_size = len(value_obs_list) |
| | game_segment_batch_size = len(pos_in_game_segment_list) |
| |
|
| | to_play, action_mask = self._preprocess_to_play_and_action_mask( |
| | game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list |
| | ) |
| | if self._cfg.model.continuous_action_space is True: |
| | |
| | action_mask = [ |
| | list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) |
| | ] |
| | |
| | legal_actions = [ |
| | [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) |
| | ] |
| | else: |
| | legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] |
| |
|
| | batch_target_values, batch_value_prefixs = [], [] |
| | with torch.no_grad(): |
| | value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) |
| | |
| | slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) |
| | network_output = [] |
| | for i in range(slices): |
| | beg_index = self._cfg.mini_infer_size * i |
| | end_index = self._cfg.mini_infer_size * (i + 1) |
| | m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() |
| |
|
| | |
| | m_output = model.initial_inference(m_obs) |
| |
|
| | |
| | if not model.training: |
| | |
| | [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( |
| | [ |
| | m_output.latent_state, |
| | inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), |
| | m_output.policy_logits |
| | ] |
| | ) |
| | m_output.reward_hidden_state = ( |
| | m_output.reward_hidden_state[0].detach().cpu().numpy(), |
| | m_output.reward_hidden_state[1].detach().cpu().numpy() |
| | ) |
| |
|
| | network_output.append(m_output) |
| |
|
| | |
| | if self._cfg.use_root_value: |
| | |
| | |
| | _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( |
| | network_output, data_type='efficientzero' |
| | ) |
| | value_prefix_pool = value_prefix_pool.squeeze().tolist() |
| | policy_logits_pool = policy_logits_pool.tolist() |
| | |
| | noises = [ |
| | np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions |
| | ).astype(np.float32).tolist() for _ in range(transition_batch_size) |
| | ] |
| |
|
| | if self._cfg.mcts_ctree: |
| | |
| | |
| | roots = MCTSCtree.roots( |
| | transition_batch_size, legal_actions, self._cfg.model.action_space_size, |
| | self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space |
| | ) |
| |
|
| | roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) |
| | |
| | MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) |
| | else: |
| | |
| | roots = MCTSPtree.roots( |
| | transition_batch_size, legal_actions, self._cfg.model.action_space_size, |
| | self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space |
| | ) |
| | roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) |
| | |
| | MCTSPtree.roots(self._cfg |
| | ).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) |
| |
|
| | roots_values = roots.get_values() |
| | value_list = np.array(roots_values) |
| | else: |
| | |
| | value_list = concat_output_value(network_output) |
| |
|
| | |
| | if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: |
| | |
| | value_list = value_list.reshape(-1) * np.array( |
| | [ |
| | self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % |
| | 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] |
| | for i in range(transition_batch_size) |
| | ] |
| | ) |
| | else: |
| | value_list = value_list.reshape(-1) * ( |
| | np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list |
| | ) |
| |
|
| | value_list = value_list * np.array(value_mask) |
| | value_list = value_list.tolist() |
| |
|
| | horizon_id, value_index = 0, 0 |
| | for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, |
| | pos_in_game_segment_list, |
| | to_play_segment): |
| | target_values = [] |
| | target_value_prefixs = [] |
| |
|
| | value_prefix = 0.0 |
| | base_index = state_index |
| | for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): |
| | bootstrap_index = current_index + td_steps_list[value_index] |
| | |
| | for i, reward in enumerate(reward_list[current_index:bootstrap_index]): |
| | if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: |
| | |
| | if to_play_list[base_index] == to_play_list[i]: |
| | value_list[value_index] += reward * self._cfg.discount_factor ** i |
| | else: |
| | value_list[value_index] += -reward * self._cfg.discount_factor ** i |
| | else: |
| | value_list[value_index] += reward * self._cfg.discount_factor ** i |
| | |
| |
|
| | |
| | if horizon_id % self._cfg.lstm_horizon_len == 0: |
| | value_prefix = 0.0 |
| | base_index = current_index |
| | horizon_id += 1 |
| |
|
| | if current_index < game_segment_len_non_re: |
| | target_values.append(value_list[value_index]) |
| | |
| | |
| | value_prefix += reward_list[current_index |
| | ] |
| | target_value_prefixs.append(value_prefix) |
| | else: |
| | target_values.append(0) |
| | target_value_prefixs.append(value_prefix) |
| |
|
| | value_index += 1 |
| |
|
| | batch_value_prefixs.append(target_value_prefixs) |
| | batch_target_values.append(target_values) |
| |
|
| | batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object) |
| | batch_target_values = np.asarray(batch_target_values, dtype=object) |
| |
|
| | return batch_value_prefixs, batch_target_values |
| |
|
| | def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: |
| | """ |
| | Overview: |
| | prepare policy targets from the reanalyzed context of policies |
| | Arguments: |
| | - policy_re_context (:obj:`List`): List of policy context to reanalyzed |
| | Returns: |
| | - batch_target_policies_re |
| | """ |
| | if policy_re_context is None: |
| | return [] |
| | batch_target_policies_re = [] |
| |
|
| | policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ |
| | to_play_segment = policy_re_context |
| | |
| | transition_batch_size = len(policy_obs_list) |
| | game_segment_batch_size = len(pos_in_game_segment_list) |
| |
|
| | to_play, action_mask = self._preprocess_to_play_and_action_mask( |
| | game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list |
| | ) |
| | if self._cfg.model.continuous_action_space is True: |
| | |
| | action_mask = [ |
| | list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) |
| | ] |
| | |
| | legal_actions = [ |
| | [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) |
| | ] |
| | else: |
| | legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] |
| |
|
| | with torch.no_grad(): |
| | policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) |
| | |
| | self._cfg.mini_infer_size = self._cfg.mini_infer_size |
| | slices = np.ceil(transition_batch_size / self._cfg.mini_infer_size).astype(np.int_) |
| | network_output = [] |
| | for i in range(slices): |
| | beg_index = self._cfg.mini_infer_size * i |
| | end_index = self._cfg.mini_infer_size * (i + 1) |
| | m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() |
| |
|
| | m_output = model.initial_inference(m_obs) |
| |
|
| | if not model.training: |
| | |
| | [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( |
| | [ |
| | m_output.latent_state, |
| | inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), |
| | m_output.policy_logits |
| | ] |
| | ) |
| | m_output.reward_hidden_state = ( |
| | m_output.reward_hidden_state[0].detach().cpu().numpy(), |
| | m_output.reward_hidden_state[1].detach().cpu().numpy() |
| | ) |
| |
|
| | network_output.append(m_output) |
| |
|
| | _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output( |
| | network_output, data_type='efficientzero' |
| | ) |
| |
|
| | value_prefix_pool = value_prefix_pool.squeeze().tolist() |
| | policy_logits_pool = policy_logits_pool.tolist() |
| | noises = [ |
| | np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions |
| | ).astype(np.float32).tolist() for _ in range(transition_batch_size) |
| | ] |
| | if self._cfg.mcts_ctree: |
| | |
| | |
| | |
| | |
| | roots = MCTSCtree.roots( |
| | transition_batch_size, legal_actions, self._cfg.model.action_space_size, |
| | self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space |
| | ) |
| | roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) |
| | |
| | MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) |
| | else: |
| | |
| | roots = MCTSPtree.roots( |
| | transition_batch_size, legal_actions, self._cfg.model.action_space_size, |
| | self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space |
| | ) |
| | roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play) |
| | |
| | MCTSPtree.roots(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play) |
| |
|
| | roots_legal_actions_list = legal_actions |
| | roots_distributions = roots.get_distributions() |
| |
|
| | |
| | |
| | |
| | roots_sampled_actions = roots.get_sampled_actions() |
| | try: |
| | root_sampled_actions = np.array([action.value for action in roots_sampled_actions]) |
| | except Exception: |
| | root_sampled_actions = np.array([action for action in roots_sampled_actions]) |
| |
|
| | policy_index = 0 |
| | for state_index, game_idx in zip(pos_in_game_segment_list, batch_index_list): |
| | target_policies = [] |
| | for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): |
| | distributions = roots_distributions[policy_index] |
| | |
| | |
| | |
| | if policy_mask[policy_index] == 0: |
| | |
| | target_policies.append([0 for _ in range(self._cfg.model.num_of_sampled_actions)]) |
| | else: |
| | if distributions is None: |
| | |
| | target_policies.append( |
| | list( |
| | np.ones(self._cfg.model.num_of_sampled_actions) / |
| | self._cfg.model.num_of_sampled_actions |
| | ) |
| | ) |
| | else: |
| | if self._cfg.action_type == 'fixed_action_space': |
| | sum_visits = sum(distributions) |
| | policy = [visit_count / sum_visits for visit_count in distributions] |
| | target_policies.append(policy) |
| | else: |
| | |
| | policy_tmp = [0 for _ in range(self._cfg.model.num_of_sampled_actions)] |
| | |
| | sum_visits = sum(distributions) |
| | policy = [visit_count / sum_visits for visit_count in distributions] |
| | for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): |
| | policy_tmp[legal_action] = policy[index] |
| | target_policies.append(policy_tmp) |
| |
|
| | policy_index += 1 |
| |
|
| | batch_target_policies_re.append(target_policies) |
| |
|
| | batch_target_policies_re = np.array(batch_target_policies_re) |
| |
|
| | return batch_target_policies_re, root_sampled_actions |
| |
|
| | def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: |
| | """ |
| | Overview: |
| | Update the priority of training data. |
| | Arguments: |
| | - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. |
| | - batch_priorities (:obj:`batch_priorities`): priorities to update to. |
| | NOTE: |
| | train_data = [current_batch, target_batch] |
| | current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list] |
| | """ |
| |
|
| | batch_index_list = train_data[0][4] |
| | metas = {'make_time': train_data[0][6], 'batch_priorities': batch_priorities} |
| | |
| | for i in range(len(batch_index_list)): |
| | if metas['make_time'][i] > self.clear_time: |
| | idx, prio = batch_index_list[i], metas['batch_priorities'][i] |
| | self.game_pos_priorities[idx] = prio |
| |
|