File size: 6,877 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import os
from dataclasses import dataclass
from typing import Any
import numpy as np
from graphviz import Digraph
def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int,
                                     reshape=False):
    """
    Overview:
        Draw ``num_actions`` independent random discrete actions with ``np.random.randint``.
    Arguments:
        - num_actions (:obj:`int`): How many entries the returned list contains.
        - action_space_size (:obj:`int`): Exclusive upper bound of the sampled integers (lower bound is 0).
        - num_of_sampled_actions (:obj:`int`): How many integers are drawn for each entry.
        - reshape (:obj:`bool`): If True and ``num_of_sampled_actions > 1``, each entry is returned with \
            shape ``(num_of_sampled_actions, 1)`` instead of a flat array.
    Returns:
        - actions (:obj:`list`): One entry per action. When ``num_of_sampled_actions == 1`` every entry is a \
            single number; otherwise every entry is a numpy array.
    """
    sampled = []
    for _ in range(num_actions):
        # One randint call per entry, flattened to 1-D.
        draw = np.random.randint(0, action_space_size, num_of_sampled_actions).reshape(-1)
        if num_of_sampled_actions == 1:
            # A single sampled action is unwrapped to a plain number.
            sampled.append(draw[0])
        elif reshape and num_of_sampled_actions > 1:
            # Column layout requested by the caller.
            sampled.append(draw.reshape(num_of_sampled_actions, 1))
        else:
            sampled.append(draw)
    return sampled
@dataclass
class BufferedData:
    """
    Overview:
        Plain record bundling a payload with a string index and a metadata dict.
    """
    # The stored payload; no type restriction.
    data: Any
    # String identifier of this record.
    index: str
    # Extra information attached to the record.
    meta: dict
def get_augmented_data(board_size, play_data):
    """
    Overview:
        Enlarge a data set with rotated and mirrored copies of every sample. Each input sample yields
        eight output samples: for each of 1..4 quarter turns, the rotated sample followed by its
        left-right mirrored version.
    Arguments:
        - board_size (:obj:`int`): Side length used to reshape ``mcts_prob`` into ``(board_size, board_size)``.
        - play_data: Iterable of dicts, each holding the keys ``'state'`` (sequence of 2-D planes),
            ``'mcts_prob'`` (array with ``board_size * board_size`` elements) and ``'winner'``.
    Returns:
        - extend_data (:obj:`list`): List of dicts with the keys ``'state'``, ``'mcts_prob'`` and ``'winner'``.
    """
    # NOTE(review): the probabilities are flipped vertically before and after the rotation, which
    # presumably assumes the state planes use an inverted row order relative to the flattened
    # ``mcts_prob`` -- confirm against the code that produces ``play_data``.
    augmented = []
    for sample in play_data:
        planes = sample['state']
        prob = sample['mcts_prob']
        outcome = sample['winner']
        for quarter_turns in range(1, 5):
            # Counterclockwise rotation of every plane and of the (vertically flipped) policy grid.
            rotated_planes = np.array([np.rot90(plane, quarter_turns) for plane in planes])
            rotated_prob = np.rot90(np.flipud(prob.reshape(board_size, board_size)), quarter_turns)
            augmented.append(
                {
                    'state': rotated_planes,
                    'mcts_prob': np.flipud(rotated_prob).flatten(),
                    'winner': outcome
                }
            )

            # Horizontal mirror of the rotated sample.
            mirrored_planes = np.array([np.fliplr(plane) for plane in rotated_planes])
            mirrored_prob = np.fliplr(rotated_prob)
            augmented.append(
                {
                    'state': mirrored_planes,
                    'mcts_prob': np.flipud(mirrored_prob).flatten(),
                    'winner': outcome
                }
            )
    return augmented
def prepare_observation(observation_list, model_type='conv'):
    """
    Overview:
        Convert a batch of stacked observations into the layout expected by the model.
        if model_type='conv':
            [B, S, W, H, C] -> [B, S x C, W, H]
            [B, S, O] -> [B, S, O, 1]
            where B is batch size, S is stack num, W is width, H is height, C is the number of channels
            and O is the original (vector) obs shape. Arrays of any other rank are only reshaped to
            [B, -1, shape[-2], shape[-1]].
        if model_type='mlp':
            [B, S, O] -> [B, S x O]
            where B is batch size, S is stack num, O is obs shape.
    Arguments:
        - observation_list (:obj:`List`): list of observations.
        - model_type (:obj:`str`): type of the model, either 'conv' or 'mlp'. (default is 'conv')
    Returns:
        - observation_array (:obj:`np.ndarray`): the rearranged observations.
    """
    assert model_type in ['conv', 'mlp']
    obs = np.array(observation_list)

    if model_type == 'mlp':
        # Vector obs: merge the stack axis and everything after it, [B, S, O] -> [B, S*O].
        obs = obs.reshape(obs.shape[0], -1)
    elif model_type == 'conv':
        rank = obs.ndim
        if rank == 3:
            # Vector obs fed to a conv model (e.g. classical control and box2d environments):
            # append a trailing singleton axis, [B, S, O] -> [B, S, O, 1].
            obs = np.expand_dims(obs, axis=-1)
        elif rank == 5:
            # Image obs (e.g. atari environments): channels-last to channels-first,
            # [B, S, W, H, C] -> [B, S, C, W, H].
            obs = obs.transpose(0, 1, 4, 2, 3)

        # Fold every axis between the batch axis and the last two axes into one:
        # [B, S, C, W, H] -> [B, S*C, W, H].
        batch, second_last, last = obs.shape[0], obs.shape[-2], obs.shape[-1]
        obs = obs.reshape((batch, -1, second_last, last))

    return obs
def obtain_tree_topology(root, to_play=-1):
    """
    Overview:
        Walk the search tree below ``root`` depth-first (explicit stack) and collect a description of
        every visited node and of every parent -> expanded-child edge.
        As a side effect, ``parent_simulation_index`` is set on every expanded child that is visited.
    Arguments:
        - root: Tree node exposing ``simulation_index``, ``visit_count``, ``prior``, ``value``,
            ``legal_actions`` and ``get_child(action)``; children must expose ``expanded``.
        - to_play: Accepted but not used by this function.
    Returns:
        - edge_topology_list (:obj:`list`): Dicts with the keys ``'parent_id'`` and ``'child_id'``.
        - node_id_list (:obj:`list`): ``simulation_index`` of every visited node, in visiting order.
        - node_topology_list (:obj:`list`): Dicts with the keys ``'node_id'``, ``'visit_count'``,
            ``'policy_prior'`` and ``'value'``, in visiting order.
    """
    edges, visited_ids, node_records = [], [], []
    pending = [root]

    while pending:
        current = pending.pop()

        node_records.append(
            {
                'node_id': current.simulation_index,
                'visit_count': current.visit_count,
                'policy_prior': current.prior,
                'value': current.value,
            }
        )
        visited_ids.append(current.simulation_index)

        for action in current.legal_actions:
            successor = current.get_child(action)
            # Children that were never expanded are not part of the plotted tree.
            if not successor.expanded:
                continue
            successor.parent_simulation_index = current.simulation_index
            edges.append(
                {
                    'parent_id': current.simulation_index,
                    'child_id': successor.simulation_index,
                }
            )
            pending.append(successor)

    return edges, visited_ids, node_records
def plot_simulation_graph(env_root, current_step, graph_directory=None):
    """
    Overview:
        Render the search tree below ``env_root`` as a graphviz directed graph and save it as a PNG.
        Every node is labelled with its id, visit count, policy prior and value; every edge is
        labelled ``"<parent_id>-<child_id>"``.
    Arguments:
        - env_root: Root node of the tree, passed to ``obtain_tree_topology``.
        - current_step: Step identifier embedded in the output file name.
        - graph_directory: Directory that receives the output (``str`` or path-like). It is created
            if missing. Defaults to ``'./data_visualize/'``.
    """
    edge_topology_list, _, node_topology_list = obtain_tree_topology(env_root)

    dot = Digraph(comment='this is direction')
    for node_topology in node_topology_list:
        node_name = str(node_topology['node_id'])
        label = f"node_id: {node_topology['node_id']}, \n visit_count: {node_topology['visit_count']}, \n policy_prior: {round(node_topology['policy_prior'], 4)}, \n value: {round(node_topology['value'], 4)}"
        dot.node(node_name, label=label)
    for edge_topology in edge_topology_list:
        parent_id = str(edge_topology['parent_id'])
        child_id = str(edge_topology['child_id'])
        label = parent_id + '-' + child_id
        dot.edge(parent_id, child_id, label=label)

    if graph_directory is None:
        graph_directory = './data_visualize/'
    # exist_ok avoids the check-then-create race when several workers plot concurrently.
    os.makedirs(graph_directory, exist_ok=True)
    # Join instead of concatenating so a directory given without a trailing separator
    # (e.g. './out') still receives the file instead of a sibling path './outsimulation_...'.
    graph_path = os.path.join(graph_directory, 'simulation_visualize_' + str(current_step) + 'step.gv')
    dot.format = 'png'
    dot.render(graph_path, view=False)
|