# gomoku/LightZero/lzero/mcts/buffer/game_segment.py
# zjowowen's picture
# init space
# history blame
# 15.5 kB
import copy
from typing import List, Tuple
import numpy as np
from easydict import EasyDict
from ding.utils.compression_helper import jpeg_data_decompressor
class GameSegment:
    """
    Overview:
        A game segment from a full episode trajectory.

        The length of one episode in (Atari) games is often quite large. This class represents a single game segment
        within a larger trajectory, split into several blocks.

    Interfaces:
        - __init__
        - __len__
        - reset
        - pad_over
        - is_full
        - legal_actions
        - append
        - get_unroll_obs
        - get_obs
        - zero_obs
        - get_targets
        - game_segment_to_array
        - store_search_stats
    """

    def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None:
        """
        Overview:
            Init the ``GameSegment`` according to the provided arguments.
        Arguments:
            - action_space: the action space; ``legal_actions`` reads its ``n`` attribute.
            - game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block
            - config (:obj:`EasyDict`): configuration object. Despite the ``None`` default it is required: its
              attributes are read unconditionally below.
        """
        self.action_space = action_space
        self.game_segment_length = game_segment_length
        self.num_unroll_steps = config.num_unroll_steps
        self.td_steps = config.td_steps
        self.frame_stack_num = config.model.frame_stack_num
        self.discount_factor = config.discount_factor
        self.action_space_size = config.model.action_space_size
        self.gray_scale = config.gray_scale
        self.transform2string = config.transform2string
        self.sampled_algo = config.sampled_algo
        self.gumbel_algo = config.gumbel_algo
        self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

        # NOTE: ``zero_obs_shape`` is only assigned for the two shape kinds handled below; for any other
        # observation shape the attribute stays undefined and ``zero_obs`` will fail.
        if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
            # for vector obs input, e.g. classical control and box2d environments
            self.zero_obs_shape = config.model.observation_shape
        elif len(config.model.observation_shape) == 3:
            # image obs input, e.g. atari environments
            self.zero_obs_shape = (
                config.model.observation_shape[-2], config.model.observation_shape[-1], config.model.image_channel
            )

        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []

        self.action_mask_segment = []
        self.to_play_segment = []

        self.target_values = []
        self.target_rewards = []
        self.target_policies = []

        self.improved_policy_probs = []

        if self.sampled_algo:
            self.root_sampled_actions = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

    def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
        """
        Overview:
            Get an observation of the correct format: o[t, t + stack frames + num_unroll_steps].
        Arguments:
            - timestep (int): The time step.
            - num_unroll_steps (int): The extra length of the observation frames.
            - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory.
        Returns:
            - stacked_obs: the selected frames. When ``transform2string`` is set, this is a list of frames
              passed through ``jpeg_data_decompressor``.
        """
        window = self.frame_stack_num + num_unroll_steps
        stacked_obs = self.obs_segment[timestep:timestep + window]
        if padding:
            pad_len = window - len(stacked_obs)
            if pad_len > 0:
                # Repeat the last available frame until the window is full.
                pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)])
                stacked_obs = np.concatenate((stacked_obs, pad_frames))
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def zero_obs(self) -> List:
        """
        Overview:
            Return an observation frame filled with zeros.
        Returns:
            - list: ``frame_stack_num`` float32 arrays of shape ``zero_obs_shape``, all filled with zeros.
        """
        return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)]

    def get_obs(self) -> List:
        """
        Overview:
            Return an observation in the correct format for model inference.
        Returns:
            - stacked_obs (List): An observation in the correct format for model inference, i.e. the latest
              ``frame_stack_num`` frames of ``obs_segment``.
        """
        timestep_obs = len(self.obs_segment) - self.frame_stack_num
        timestep_reward = len(self.reward_segment)
        # The observation list must lead the reward list by exactly the initial frame stack.
        assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
            timestep_obs, timestep_reward
        )
        timestep = timestep_reward
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def append(
            self,
            action: np.ndarray,
            obs: np.ndarray,
            reward: np.ndarray,
            action_mask: np.ndarray = None,
            to_play: int = -1,
            chance: int = 0,
    ) -> None:
        """
        Overview:
            Append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}.
        Arguments:
            - action: a_t
            - obs: o_{t+1}
            - reward: r_t
            - action_mask: action_mask_t
            - to_play: to_play_t
            - chance: chance label; stored only when ``use_ture_chance_label_in_chance_encoder`` is True.
        """
        self.action_segment.append(action)
        self.obs_segment.append(obs)
        self.reward_segment.append(reward)

        self.action_mask_segment.append(action_mask)
        self.to_play_segment.append(to_play)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment.append(chance)

    def pad_over(
            self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
            next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
    ) -> None:
        """
        Overview:
            To make sure the correction of value targets, we need to add (o_t, r_t, etc) from the next game_segment
            , which is necessary for the bootstrapped values at the end states of previous game_segment.
            e.g: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
            but r_101, r_102, ... are from the next game_segment.
        Arguments:
            - next_segment_observations (:obj:`list`): o_t from the next game_segment
            - next_segment_rewards (:obj:`list`): r_t from the next game_segment
            - next_segment_root_values (:obj:`list`): root values of MCTS from the next game_segment
            - next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment
            - next_segment_improved_policy (:obj:`list`): root children select policy of MCTS from the next game_segment (Only used in Gumbel MuZero)
            - next_chances (:obj:`list`): chance labels from the next game_segment; only read when
              ``use_ture_chance_label_in_chance_encoder`` is True.
        """
        assert len(next_segment_observations) <= self.num_unroll_steps
        assert len(next_segment_child_visits) <= self.num_unroll_steps
        assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
        assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
        # ==============================================================
        # The core difference between GumbelMuZero and MuZero
        # ==============================================================
        if self.gumbel_algo:
            assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps

        # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
        for observation in next_segment_observations:
            # Copy so this segment does not share frames with the next one.
            self.obs_segment.append(copy.deepcopy(observation))

        for reward in next_segment_rewards:
            self.reward_segment.append(reward)

        for value in next_segment_root_values:
            self.root_value_segment.append(value)

        for child_visits in next_segment_child_visits:
            self.child_visit_segment.append(child_visits)

        if self.gumbel_algo:
            for improved_policy in next_segment_improved_policy:
                self.improved_policy_probs.append(improved_policy)
        if self.use_ture_chance_label_in_chance_encoder:
            for chances in next_chances:
                self.chance_segment.append(chances)

    def get_targets(self, timestep: int) -> Tuple:
        """
        Overview:
            return the value/reward/policy targets at step timestep
        Arguments:
            - timestep (int): index into the stored target lists.
        Returns:
            - tuple: ``(target_value, target_reward, target_policy)`` at ``timestep``.
        """
        return self.target_values[timestep], self.target_rewards[timestep], self.target_policies[timestep]

    def store_search_stats(
            self, visit_counts: List, root_value: List, root_sampled_actions: List = None, improved_policy: List = None, idx: int = None
    ) -> None:
        """
        Overview:
            store the visit count distributions and value of the root node after MCTS.
        Arguments:
            - visit_counts (:obj:`list`): raw visit counts of the root children; stored normalised by their sum.
            - root_value: value of the root node.
            - root_sampled_actions (:obj:`list`): stored only when ``sampled_algo`` is True.
            - improved_policy (:obj:`list`): stored only when ``gumbel_algo`` is True.
            - idx (:obj:`int`): if ``None`` the statistics are appended, otherwise the entries at ``idx``
              are overwritten.
        """
        sum_visits = sum(visit_counts)
        visit_distribution = [visit_count / sum_visits for visit_count in visit_counts]
        if idx is None:
            self.child_visit_segment.append(visit_distribution)
            self.root_value_segment.append(root_value)
            if self.sampled_algo:
                self.root_sampled_actions.append(root_sampled_actions)
            # store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ))
            if self.gumbel_algo:
                self.improved_policy_probs.append(improved_policy)
        else:
            self.child_visit_segment[idx] = visit_distribution
            self.root_value_segment[idx] = root_value
            # ``improved_policy_probs`` is only ever filled when ``gumbel_algo`` is on; writing to it
            # unconditionally would raise IndexError on the empty container for every other algorithm.
            if self.gumbel_algo:
                self.improved_policy_probs[idx] = improved_policy

    def game_segment_to_array(self) -> None:
        """
        Overview:
            Post-process the data when a `GameSegment` block is full. This function converts various game segment
            elements into numpy arrays for easier manipulation and processing.

        Structure:
            The structure and shapes of different game segment elements are as follows. Let's assume
            `game_segment_length`=20, `stack`=4, `num_unroll_steps`=5, `td_steps`=5:

            - obs:            game_segment_length + stack + num_unroll_steps, 20+4+5
            - action:         game_segment_length -> 20
            - reward:         game_segment_length + num_unroll_steps + td_steps -1  20+5+5-1
            - root_values:    game_segment_length + num_unroll_steps + td_steps -> 20+5+5
            - child_visits:   game_segment_length + num_unroll_steps -> 20+5
            - to_play:        game_segment_length -> 20
            - action_mask:    game_segment_length -> 20

            Here is an illustration of the structure of `obs` and `rew` for two consecutive game segments
            (game_segment_i and game_segment_i+1):

            - game_segment_i (obs):     4       20        5
            - game_segment_i+1 (obs):               4       20        5
            - game_segment_i (rew):             20        5      4
            - game_segment_i+1 (rew):                     20        5    4

        Postprocessing:
            - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment.
            - self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment.
            - self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment.
            - self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original child_visit_segment.
            - self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original root_value_segment.
            - self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original improved_policy_probs.
            - self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_mask_segment.
            - self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment.
            - self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original chance_segment. Only
              created if `self.use_ture_chance_label_in_chance_encoder` is True.

        .. note::
            For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have
            different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`.
        """
        self.obs_segment = np.array(self.obs_segment)
        self.action_segment = np.array(self.action_segment)
        self.reward_segment = np.array(self.reward_segment)

        # Check if all elements in self.child_visit_segment have the same length
        if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
            self.child_visit_segment = np.array(self.child_visit_segment)
        else:
            # In the case of environments with a variable action space, such as board games,
            # the elements in child_visit_segment may have different lengths.
            # In such scenarios, it is necessary to use the object data type.
            self.child_visit_segment = np.array(self.child_visit_segment, dtype=object)

        self.root_value_segment = np.array(self.root_value_segment)
        self.improved_policy_probs = np.array(self.improved_policy_probs)

        self.action_mask_segment = np.array(self.action_mask_segment)
        self.to_play_segment = np.array(self.to_play_segment)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = np.array(self.chance_segment)

    def reset(self, init_observations: np.ndarray) -> None:
        """
        Overview:
            Initialize the game segment using ``init_observations``,
            which is the previous ``frame_stack_num`` stacked frames.
        Arguments:
            - init_observations (:obj:`list`): list of the stack observations in the previous time steps.

        .. note::
            Only the segment lists assigned below are cleared; the target lists, ``improved_policy_probs``
            and ``root_sampled_actions`` are left as they are.
        """
        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []

        self.action_mask_segment = []
        self.to_play_segment = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

        assert len(init_observations) == self.frame_stack_num

        for observation in init_observations:
            # Copy so later changes to the caller's frames do not leak into this segment.
            self.obs_segment.append(copy.deepcopy(observation))

    def is_full(self) -> bool:
        """
        Overview:
            Check whether the current game segment is full, i.e. larger than the segment length.
        Returns:
            bool: True if the game segment is full, False otherwise.
        """
        return len(self.action_segment) >= self.game_segment_length

    def legal_actions(self):
        """
        Overview:
            Return every action index of the action space, i.e. ``[0, ..., action_space.n - 1]``.
        """
        return list(range(self.action_space.n))

    def __len__(self):
        """
        Overview:
            Return the number of actions stored in this segment.
        """
        return len(self.action_segment)