|
import copy |
|
from typing import List, Tuple |
|
|
|
import numpy as np |
|
from easydict import EasyDict |
|
|
|
from ding.utils.compression_helper import jpeg_data_decompressor |
|
|
|
|
|
class GameSegment: |
|
""" |
|
Overview: |
|
A game segment from a full episode trajectory. |
|
|
|
        The length of one episode in (Atari) games is often quite large, so a full episode trajectory is split
        into several blocks; this class represents a single such game segment.
|
|
|
Interfaces: |
|
- __init__ |
|
- __len__ |
|
- reset |
|
- pad_over |
|
- is_full |
|
- legal_actions |
|
- append |
|
        - get_unroll_obs
        - zero_obs
        - get_obs
|
- get_targets |
|
- game_segment_to_array |
|
- store_search_stats |
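    Examples:
        A minimal sketch of the intended call flow; ``cfg``, ``env`` and the transition variables below are
        illustrative placeholders, not objects defined in this module:

        >>> segment = GameSegment(action_space=env.action_space, game_segment_length=200, config=cfg)
        >>> segment.reset(init_stacked_obs)
        >>> segment.append(action, obs, reward, action_mask, to_play)
        >>> if segment.is_full():
        ...     segment.game_segment_to_array()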
|
""" |
|
|
|
def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None: |
|
""" |
|
Overview: |
|
Init the ``GameSegment`` according to the provided arguments. |
|
Arguments: |
|
            - action_space (:obj:`int`): The action space of the environment.
            - game_segment_length (:obj:`int`): The number of transitions in one ``GameSegment`` block.
            - config (:obj:`EasyDict`): The algorithm configuration, which provides fields such as ``num_unroll_steps``, ``td_steps`` and ``model``.
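        Examples:
            An illustrative, minimal config containing only the fields read by this class; every value below is a
            placeholder assumption, and ``env`` is a placeholder environment:

            >>> from easydict import EasyDict
            >>> cfg = EasyDict(dict(
            ...     num_unroll_steps=5, td_steps=5, discount_factor=0.997,
            ...     gray_scale=False, transform2string=False, sampled_algo=False, gumbel_algo=False,
            ...     use_ture_chance_label_in_chance_encoder=False,
            ...     model=dict(frame_stack_num=4, action_space_size=6, observation_shape=(3, 96, 96), image_channel=3),
            ... ))
            >>> segment = GameSegment(action_space=env.action_space, game_segment_length=200, config=cfg)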
|
""" |
|
self.action_space = action_space |
|
self.game_segment_length = game_segment_length |
|
self.num_unroll_steps = config.num_unroll_steps |
|
self.td_steps = config.td_steps |
|
self.frame_stack_num = config.model.frame_stack_num |
|
self.discount_factor = config.discount_factor |
|
self.action_space_size = config.model.action_space_size |
|
self.gray_scale = config.gray_scale |
|
self.transform2string = config.transform2string |
|
self.sampled_algo = config.sampled_algo |
|
self.gumbel_algo = config.gumbel_algo |
|
self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder |
|
|
|
        if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
            # Vector observations, e.g. classic control or board-game environments.
            self.zero_obs_shape = config.model.observation_shape
        elif len(config.model.observation_shape) == 3:
            # Image observations, e.g. Atari environments: (H, W, C).
            self.zero_obs_shape = (
                config.model.observation_shape[-2], config.model.observation_shape[-1], config.model.image_channel
            )
|
|
|
self.obs_segment = [] |
|
self.action_segment = [] |
|
self.reward_segment = [] |
|
|
|
self.child_visit_segment = [] |
|
self.root_value_segment = [] |
|
|
|
self.action_mask_segment = [] |
|
self.to_play_segment = [] |
|
|
|
self.target_values = [] |
|
self.target_rewards = [] |
|
self.target_policies = [] |
|
|
|
self.improved_policy_probs = [] |
|
|
|
if self.sampled_algo: |
|
self.root_sampled_actions = [] |
|
if self.use_ture_chance_label_in_chance_encoder: |
|
self.chance_segment = [] |
|
|
|
|
|
def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: |
|
""" |
|
Overview: |
|
            Get the stacked observations for unrolling, i.e. o[t : t + frame_stack_num + num_unroll_steps].
        Arguments:
            - timestep (:obj:`int`): The start time step t.
            - num_unroll_steps (:obj:`int`): The number of extra unroll frames to include.
            - padding (:obj:`bool`): If True, pad with the last available frame when (t + frame_stack_num + num_unroll_steps) exceeds the trajectory length.
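        Examples:
            An illustrative call (``segment`` and the concrete numbers are assumptions): with ``padding=True``,
            the last available frame is repeated until ``frame_stack_num + num_unroll_steps`` frames are returned.

            >>> unroll_obs = segment.get_unroll_obs(timestep=198, num_unroll_steps=5, padding=True)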
|
""" |
|
stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps] |
|
if padding: |
|
pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs) |
|
if pad_len > 0: |
|
pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)]) |
|
stacked_obs = np.concatenate((stacked_obs, pad_frames)) |
|
if self.transform2string: |
|
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] |
|
return stacked_obs |
|
|
|
def zero_obs(self) -> List: |
|
""" |
|
Overview: |
|
            Return a list of ``frame_stack_num`` observation frames filled with zeros.
        Returns:
            - zero_obs (:obj:`list`): A list of ``frame_stack_num`` zero-filled frames of shape ``zero_obs_shape``.
|
""" |
|
return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)] |
|
|
|
def get_obs(self) -> List: |
|
""" |
|
Overview: |
|
            Return the current stacked observations, i.e. the last ``frame_stack_num`` frames, in the correct format for model inference.
        Returns:
            - stacked_obs (:obj:`List`): The stacked observations for model inference.
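        Examples:
            A minimal sketch (``segment`` is assumed to be a partially filled ``GameSegment``):

            >>> stacked_obs = segment.get_obs()
            >>> len(stacked_obs) == segment.frame_stack_num
            True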
|
""" |
|
timestep_obs = len(self.obs_segment) - self.frame_stack_num |
|
timestep_reward = len(self.reward_segment) |
|
assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format( |
|
timestep_obs, timestep_reward |
|
) |
|
timestep = timestep_reward |
|
stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num] |
|
if self.transform2string: |
|
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] |
|
return stacked_obs |
|
|
|
def append( |
|
self, |
|
action: np.ndarray, |
|
obs: np.ndarray, |
|
reward: np.ndarray, |
|
action_mask: np.ndarray = None, |
|
to_play: int = -1, |
|
chance: int = 0, |
|
) -> None: |
|
""" |
|
Overview: |
|
            Append a transition, including a_t, o_{t+1}, r_t, action_mask_t, to_play_t (and chance_t when ``use_ture_chance_label_in_chance_encoder`` is True).
|
""" |
|
self.action_segment.append(action) |
|
self.obs_segment.append(obs) |
|
self.reward_segment.append(reward) |
|
|
|
self.action_mask_segment.append(action_mask) |
|
self.to_play_segment.append(to_play) |
|
if self.use_ture_chance_label_in_chance_encoder: |
|
self.chance_segment.append(chance) |
|
|
|
def pad_over( |
|
self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List, |
|
next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None, |
|
) -> None: |
|
""" |
|
Overview: |
|
            To ensure the correctness of the value targets, we need to append (o_t, r_t, etc.) from the next game_segment,
            which are needed to bootstrap the values at the final states of the previous game_segment.
            e.g.: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
            but r_101, r_102, ... come from the next game_segment.
|
Arguments: |
|
- next_segment_observations (:obj:`list`): o_t from the next game_segment |
|
- next_segment_rewards (:obj:`list`): r_t from the next game_segment |
|
- next_segment_root_values (:obj:`list`): root values of MCTS from the next game_segment |
|
- next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment |
|
            - next_segment_improved_policy (:obj:`list`): the improved root policy (root children selection policy) of MCTS from the next game_segment (only used in Gumbel MuZero)
            - next_chances (:obj:`list`): the true chance labels from the next game_segment (only used when ``use_ture_chance_label_in_chance_encoder`` is True)
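        Examples:
            A sketch of the expected list lengths, assuming ``num_unroll_steps=5`` and ``td_steps=5``; the
            ``next_*`` variables are placeholders taken from the following game segment:

            >>> segment.pad_over(
            ...     next_segment_observations=next_obs[:5],
            ...     next_segment_rewards=next_rewards[:9],
            ...     next_segment_root_values=next_root_values[:10],
            ...     next_segment_child_visits=next_child_visits[:5],
            ... )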
|
""" |
|
assert len(next_segment_observations) <= self.num_unroll_steps |
|
assert len(next_segment_child_visits) <= self.num_unroll_steps |
|
assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps |
|
assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1 |
|
|
|
|
|
|
|
if self.gumbel_algo: |
|
assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps |
|
|
|
|
|
for observation in next_segment_observations: |
|
self.obs_segment.append(copy.deepcopy(observation)) |
|
|
|
for reward in next_segment_rewards: |
|
self.reward_segment.append(reward) |
|
|
|
for value in next_segment_root_values: |
|
self.root_value_segment.append(value) |
|
|
|
for child_visits in next_segment_child_visits: |
|
self.child_visit_segment.append(child_visits) |
|
|
|
if self.gumbel_algo: |
|
for improved_policy in next_segment_improved_policy: |
|
self.improved_policy_probs.append(improved_policy) |
|
if self.use_ture_chance_label_in_chance_encoder: |
|
for chances in next_chances: |
|
self.chance_segment.append(chances) |
|
|
|
def get_targets(self, timestep: int) -> Tuple: |
|
""" |
|
Overview: |
|
            Return the value/reward/policy targets at step ``timestep``.
|
""" |
|
return self.target_values[timestep], self.target_rewards[timestep], self.target_policies[timestep] |
|
|
|
def store_search_stats( |
|
self, visit_counts: List, root_value: List, root_sampled_actions: List = None, improved_policy: List = None, idx: int = None |
|
) -> None: |
|
""" |
|
Overview: |
|
            Store the normalized visit count distribution and the value of the root node after MCTS.
            If ``idx`` is given, overwrite the stats at that index instead of appending.
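        Examples:
            An illustrative call with made-up visit counts:

            >>> segment.store_search_stats(visit_counts=[10, 20, 70], root_value=0.53)
            >>> segment.child_visit_segment[-1]
            [0.1, 0.2, 0.7]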
|
""" |
|
sum_visits = sum(visit_counts) |
|
if idx is None: |
|
self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts]) |
|
self.root_value_segment.append(root_value) |
|
if self.sampled_algo: |
|
self.root_sampled_actions.append(root_sampled_actions) |
|
|
|
if self.gumbel_algo: |
|
self.improved_policy_probs.append(improved_policy) |
|
else: |
|
self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts] |
|
self.root_value_segment[idx] = root_value |
|
self.improved_policy_probs[idx] = improved_policy |
|
|
|
def game_segment_to_array(self) -> None: |
|
""" |
|
Overview: |
|
Post-process the data when a `GameSegment` block is full. This function converts various game segment |
|
elements into numpy arrays for easier manipulation and processing. |
|
Structure: |
|
The structure and shapes of different game segment elements are as follows. Let's assume |
|
`game_segment_length`=20, `stack`=4, `num_unroll_steps`=5, `td_steps`=5: |
|
|
|
            - obs: game_segment_length + stack + num_unroll_steps -> 20+4+5
|
- action: game_segment_length -> 20 |
|
            - reward: game_segment_length + num_unroll_steps + td_steps - 1 -> 20+5+5-1
|
- root_values: game_segment_length + num_unroll_steps + td_steps -> 20+5+5 |
|
- child_visits: game_segment_length + num_unroll_steps -> 20+5 |
|
- to_play: game_segment_length -> 20 |
|
- action_mask: game_segment_length -> 20 |
|
Examples: |
|
Here is an illustration of the structure of `obs` and `rew` for two consecutive game segments |
|
(game_segment_i and game_segment_i+1): |
|
|
|
            - game_segment_i     (obs):    4       20        5
                                         ----|----...----|-----|
            - game_segment_i+1   (obs):    4       20        5
                                         ----|----...----|-----|

            - game_segment_i     (rew):       20        5     4
                                         ----...----|------|-----|
            - game_segment_i+1   (rew):       20        5     4
                                         ----...----|------|-----|
|
|
|
Postprocessing: |
|
- self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment. |
|
- self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment. |
|
- self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment. |
|
- self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original child_visit_segment. |
|
- self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original root_value_segment. |
|
- self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original improved_policy_probs. |
|
- self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_mask_segment. |
|
- self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment. |
|
- self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original chance_segment. Only |
|
created if `self.use_ture_chance_label_in_chance_encoder` is True. |
|
|
|
.. note:: |
|
For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have |
|
different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`. |
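        Examples:
            A sketch of the resulting array shapes under the assumptions above (``game_segment_length=20``,
            ``stack=4``, ``num_unroll_steps=5``, ``td_steps=5``, after ``pad_over`` has been called):

            >>> segment.game_segment_to_array()
            >>> segment.obs_segment.shape[0]     # 20 + 4 + 5 = 29
            >>> segment.reward_segment.shape[0]  # 20 + 5 + 5 - 1 = 29
            >>> segment.action_segment.shape[0]  # 20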
|
""" |
|
self.obs_segment = np.array(self.obs_segment) |
|
self.action_segment = np.array(self.action_segment) |
|
self.reward_segment = np.array(self.reward_segment) |
|
|
|
|
|
        # Check whether all elements in self.child_visit_segment have the same length.
        if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
            self.child_visit_segment = np.array(self.child_visit_segment)
        else:
            # For environments with a variable action space (e.g. board games), the visit count
            # distributions may have different lengths, so the object dtype is required.
            self.child_visit_segment = np.array(self.child_visit_segment, dtype=object)
|
|
|
self.root_value_segment = np.array(self.root_value_segment) |
|
self.improved_policy_probs = np.array(self.improved_policy_probs) |
|
|
|
self.action_mask_segment = np.array(self.action_mask_segment) |
|
self.to_play_segment = np.array(self.to_play_segment) |
|
if self.use_ture_chance_label_in_chance_encoder: |
|
self.chance_segment = np.array(self.chance_segment) |
|
|
|
def reset(self, init_observations: np.ndarray) -> None: |
|
""" |
|
Overview: |
|
            Initialize the game segment with ``init_observations``,
            which are the previous ``frame_stack_num`` stacked frames.
        Arguments:
            - init_observations (:obj:`list`): The stacked observations from the previous ``frame_stack_num`` time steps.
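        Examples:
            A sketch; ``initial_obs`` is a placeholder for the first observation returned by the environment:

            >>> segment.reset([initial_obs for _ in range(segment.frame_stack_num)])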
|
""" |
|
self.obs_segment = [] |
|
self.action_segment = [] |
|
self.reward_segment = [] |
|
|
|
self.child_visit_segment = [] |
|
self.root_value_segment = [] |
|
|
|
self.action_mask_segment = [] |
|
self.to_play_segment = [] |
|
if self.use_ture_chance_label_in_chance_encoder: |
|
self.chance_segment = [] |
|
|
|
assert len(init_observations) == self.frame_stack_num |
|
|
|
for observation in init_observations: |
|
self.obs_segment.append(copy.deepcopy(observation)) |
|
|
|
def is_full(self) -> bool: |
|
""" |
|
Overview: |
|
            Check whether the current game segment is full, i.e. its length is greater than or equal to ``game_segment_length``.
|
Returns: |
|
bool: True if the game segment is full, False otherwise. |
|
""" |
|
return len(self.action_segment) >= self.game_segment_length |
|
|
|
    def legal_actions(self) -> List[int]:
        # Return all legal action indices in the current action space.
        return list(range(self.action_space.n))
|
|
|
def __len__(self): |
|
return len(self.action_segment) |
|
|