|
""" |
|
The Node, Roots class and related core functions for EfficientZero. |
|
""" |
|
import math |
|
import random |
|
from typing import List, Any, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from .minimax import MinMaxStats |
|
|
|
|
|
class Node:
    """
    Overview:
        The node class of the Python MCTS tree used by EfficientZero. A node stores the prior of the action leading
        to it, its visit statistics, its value prefix and its children (a dict keyed by action).
    """

    def __init__(self, prior: float, legal_actions: List = None, action_space_size: int = 9) -> None:
        """
        Overview:
            Create an unexpanded node.
        Arguments:
            - prior (:obj:`float`): the prior probability of the action leading to this node.
            - legal_actions (:obj:`List`): the legal actions of this node. If ``None``, ``expand`` fills it with
              ``np.arange(len(policy_logits))``.
            - action_space_size (:obj:`int`): the size of the action space, propagated to the children.
        """
        self.prior = prior
        self.legal_actions = legal_actions
        self.action_space_size = action_space_size

        # 1 means this node's value prefix is reset, so a child's reward is its raw value_prefix instead of
        # the difference to this node's value_prefix (see compute_mean_q).
        self.is_reset = 0
        self.visit_count = 0
        self.value_sum = 0
        # Action chosen at this node during the latest traversal; -1 means none yet (ends get_trajectory).
        self.best_action = -1
        self.to_play = -1
        self.value_prefix = 0.0
        self.children = {}
        # NOTE(review): children_index and parent_value_prefix are not read anywhere in this file.
        self.children_index = []
        self.simulation_index = 0
        self.batch_index = 0
        self.parent_value_prefix = 0

    def expand(
            self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float, policy_logits: List[float]
    ) -> None:
        """
        Overview:
            Expand the child nodes of the current node: one child per legal action, whose prior is the softmax of
            ``policy_logits`` restricted to the legal actions.
        Arguments:
            - to_play (:obj:`Class int`): which player to play the game in the current node.
            - simulation_index (:obj:`Class int`): the x/first index of hidden state vector of the current node, \
                i.e. the search depth.
            - batch_index (:obj:`Class int`): the y/second index of hidden state vector of the current node, \
                i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``.
            - value_prefix: (:obj:`Class float`): the value prefix of the current node.
            - policy_logits: (:obj:`Class List`): the policy logits, indexed by action.
        """
        self.to_play = to_play
        if self.legal_actions is None:
            self.legal_actions = np.arange(len(policy_logits))

        self.simulation_index = simulation_index
        self.batch_index = batch_index
        self.value_prefix = value_prefix

        legal_logits = torch.tensor([policy_logits[a] for a in self.legal_actions])
        if not legal_logits.is_floating_point():
            # torch.softmax is not implemented for integer tensors, so promote integer logits to float.
            legal_logits = legal_logits.float()
        policy_values = torch.softmax(legal_logits, dim=0).tolist()
        for action, p in zip(self.legal_actions, policy_values):
            self.children[action] = Node(p, action_space_size=self.action_space_size)

    def add_exploration_noise(self, exploration_fraction: float, noises: List[float]) -> None:
        """
        Overview:
            Mix noise into the prior of each child: ``prior * (1 - fraction) + noise * fraction``.
        Arguments:
            - exploration_fraction: the fraction of noise mixed into each prior.
            - noises (:obj: list): the noise for each child, aligned with ``self.legal_actions``, e.g. \
                legal_actions = [0, 1, 2, 4] uses noises[0..3] for actions 0, 1, 2, 4.
        Raises:
            - ValueError: if ``noises`` has fewer entries than there are legal actions.
        """
        if len(noises) < len(self.legal_actions):
            # Previously the IndexError was printed and a stale (or unbound) noise value was used instead.
            raise ValueError(
                f'expected at least {len(self.legal_actions)} noises (one per legal action), got {len(noises)}'
            )
        for a, noise in zip(self.legal_actions, noises):
            child = self.get_child(a)
            child.prior = child.prior * (1 - exploration_fraction) + noise * exploration_fraction

    def compute_mean_q(self, is_root: bool, parent_q: float, discount_factor: float) -> float:
        """
        Overview:
            Compute the mean q value over the visited children of the current node.
        Arguments:
            - is_root (:obj:`bool`): whether the current node is a root node.
            - parent_q (:obj:`float`): the q value of the parent node.
            - discount_factor (:obj:`float`): the discount_factor of reward.
        Returns:
            - mean_q (:obj:`float`): for a root with visited children, the plain mean of their q values; otherwise \
                ``parent_q`` is counted as one extra sample in the mean.
        """
        total_unsigned_q = 0.0
        total_visits = 0
        parent_value_prefix = self.value_prefix
        for a in self.legal_actions:
            child = self.get_child(a)
            if child.visit_count > 0:
                true_reward = child.value_prefix - parent_value_prefix
                if self.is_reset == 1:
                    # This node's value prefix is reset, so the child's value prefix is its own reward.
                    true_reward = child.value_prefix
                q_of_s_a = true_reward + discount_factor * child.value
                total_unsigned_q += q_of_s_a
                total_visits += 1
        if is_root and total_visits > 0:
            mean_q = total_unsigned_q / total_visits
        else:
            # parent_q acts as one extra pseudo-sample; it is also the result when no child has been visited.
            mean_q = (parent_q + total_unsigned_q) / (total_visits + 1)
        return mean_q

    def print_out(self) -> None:
        """Placeholder; does nothing."""
        pass

    def get_trajectory(self) -> List[Union[int, float]]:
        """
        Overview:
            Follow ``best_action`` from the current node until a node without one (-1) is reached.
        Outputs:
            - traj: the list of actions along the current best trajectory from this node.
        """
        traj = []
        node = self
        best_action = node.best_action
        while best_action >= 0:
            traj.append(best_action)
            node = node.get_child(best_action)
            best_action = node.best_action
        return traj

    def get_children_distribution(self) -> Union[List[int], None]:
        """
        Overview:
            Return the visit count of each legal action's child, in ``legal_actions`` order (all zeros if the node
            is not expanded yet), or ``None`` when there are no legal actions.
        """
        # len() works for lists and numpy arrays alike; ``arr == []`` is an elementwise comparison on numpy arrays.
        if self.legal_actions is None or len(self.legal_actions) == 0:
            return None
        distribution = {a: 0 for a in self.legal_actions}
        if self.expanded:
            for a in self.legal_actions:
                distribution[a] = self.get_child(a).visit_count
        return list(distribution.values())

    def get_child(self, action: Union[int, float]) -> "Node":
        """
        Overview:
            Get the child node reached by ``action`` (non-``np.int64`` actions are converted with ``int``).
        """
        if not isinstance(action, np.int64):
            action = int(action)
        return self.children[action]

    @property
    def expanded(self) -> bool:
        """Whether the node has any children."""
        return len(self.children) > 0

    @property
    def value(self) -> float:
        """
        Overview:
            Return the estimated value of the current node: ``value_sum / visit_count``, or 0 if never visited.
        """
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count
|
|
|
|
|
class Roots:
    """
    Overview:
        A batch of ``root_num`` root nodes that are searched together.
    """

    def __init__(self, root_num: int, legal_actions_list: List) -> None:
        """
        Overview:
            Create one unexpanded root per batch entry.
        Arguments:
            - root_num (:obj:`int`): the number of roots.
            - legal_actions_list (:obj:`List` or :obj:`int`): either a list holding the legal actions of each root, \
                or an int ``n`` meaning every root has legal actions ``np.arange(n)``.
        """
        self.num = root_num
        self.root_num = root_num
        self.legal_actions_list = legal_actions_list

        self.roots = []
        for i in range(self.root_num):
            if isinstance(legal_actions_list, list):
                self.action_space_size = len(legal_actions_list[i])
                self.roots.append(Node(0, legal_actions_list[i], action_space_size=self.action_space_size))
            else:
                self.action_space_size = legal_actions_list
                self.roots.append(Node(0, np.arange(legal_actions_list), action_space_size=self.action_space_size))

    @staticmethod
    def _root_to_play(to_play: Any, i: int) -> Any:
        """Return the player of root ``i``: -1 for ``None``/-1, the scalar itself for other scalars, else ``to_play[i]``."""
        # np.ndim handles Python and numpy scalars as well as lists and arrays; ``to_play in [-1, None]``
        # raises for numpy arrays with more than one element.
        if to_play is None or (np.ndim(to_play) == 0 and to_play == -1):
            return -1
        if np.ndim(to_play) == 0:
            return to_play
        return to_play[i]

    def prepare(
            self,
            root_noise_weight: float,
            noises: List[float],
            value_prefixs: List[float],
            policies: List[List[float]],
            to_play: int = -1
    ) -> None:
        """
        Overview:
            Expand the roots, add exploration noise to their children's priors and count one visit per root.
        Arguments:
            - root_noise_weight: the exploration fraction of roots.
            - noises: the noise vector of each root (at least one entry per legal action of that root).
            - value_prefixs: the value prefix of each root.
            - policies: the policy logits of each root.
            - to_play: ``-1``/``None`` for single-player search, otherwise the player of each root (a scalar is \
                used for every root).
        """
        for i in range(self.root_num):
            root = self.roots[i]
            root.expand(self._root_to_play(to_play, i), 0, i, value_prefixs[i], policies[i])
            root.add_exploration_noise(root_noise_weight, noises[i])
            root.visit_count += 1

    def prepare_no_noise(self, value_prefixs: List[float], policies: List[List[float]], to_play: int = -1) -> None:
        """
        Overview:
            Expand the roots without exploration noise and count one visit per root.
        Arguments:
            - value_prefixs: the value prefix of each root.
            - policies: the policy logits of each root.
            - to_play: ``-1``/``None`` for single-player search, otherwise the player of each root (a scalar is \
                used for every root).
        """
        for i in range(self.root_num):
            root = self.roots[i]
            root.expand(self._root_to_play(to_play, i), 0, i, value_prefixs[i], policies[i])
            root.visit_count += 1

    def clear(self) -> None:
        """Remove all root nodes."""
        self.roots.clear()

    def get_trajectories(self) -> List[List[Union[int, float]]]:
        """
        Overview:
            Find the current best trajectory starting from each root.
        Outputs:
            - trajs: for each root, the list of actions along its current best trajectory.
        """
        return [root.get_trajectory() for root in self.roots]

    def get_distributions(self) -> List[List[Union[int, float]]]:
        """
        Overview:
            Get the children distribution of each root.
        Outputs:
            - distributions: for each root, the visit counts of its children (e.g. [1, 3, 0, 2, 5]).
        """
        return [root.get_children_distribution() for root in self.roots]

    def get_values(self) -> List[float]:
        """
        Overview:
            Return the estimated value of each root.
        """
        return [root.value for root in self.roots]
|
|
|
|
|
class SearchResults:
    """
    Overview:
        Per-simulation bookkeeping for a batch of ``num`` searches: the leaf nodes reached, the search paths and
        the latent-state indices / actions needed to expand the leaves. Every container starts out empty.
    """

    def __init__(self, num: int) -> None:
        self.num: int = num
        # One entry per root, filled in during traversal.
        self.search_lens: List = []
        self.last_actions: List = []
        self.latent_state_index_in_batch: List = []
        self.latent_state_index_in_search_path: List = []
        self.search_paths: List = []
        self.nodes: List = []
|
|
|
|
|
def select_child(
        root: Node, min_max_stats: MinMaxStats, pb_c_base: float, pb_c_int: float, discount_factor: float,
        mean_q: float, players: int
) -> Union[int, float]:
    """
    Overview:
        Select the child of ``root`` with the highest ucb score, breaking near-ties (within 1e-6) uniformly at random.
    Arguments:
        - root: the node whose child is selected.
        - min_max_stats (:obj:`Class MinMaxStats`): a tool used to min-max normalize the score.
        - pb_c_base (:obj:`Class Float`): constant c2 used in pUCT rule, typically 19652.
        - pb_c_int (:obj:`Class Float`): constant c1 used in pUCT rule, typically 1.25.
        - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, \
            if env is board_games, we set discount_factor=1.
        - mean_q (:obj:`Class Float`): the mean q value of the parent node.
        - players (:obj:`Class Int`): the number of players, 1 or 2 (self-play-mode board games).
    Returns:
        - action (:obj:`Union[int, float]`): the chosen action; 0 if no candidate was found.
    """
    tie_tolerance = 0.000001
    best_score = -np.inf
    candidates = []
    for action in root.legal_actions:
        score = compute_ucb_score(
            root.get_child(action), min_max_stats, mean_q, root.is_reset, root.visit_count, root.value_prefix,
            pb_c_base, pb_c_int, discount_factor, players
        )
        if score > best_score:
            # A strictly better score restarts the candidate list.
            best_score = score
            candidates = [action]
        elif score >= best_score - tie_tolerance:
            candidates.append(action)
    return random.choice(candidates) if candidates else 0
|
|
|
|
|
def compute_ucb_score(
        child: "Node",
        min_max_stats: "MinMaxStats",
        parent_mean_q: float,
        is_reset: int,
        total_children_visit_counts: float,
        parent_value_prefix: float,
        pb_c_base: float,
        pb_c_init: float,
        discount_factor: float,
        players: int = 1,
) -> float:
    """
    Overview:
        Compute the pUCT score of the child: a prior term plus the min-max normalized q value clamped to [0, 1].
        Unvisited children use ``parent_mean_q`` as their value estimate.
    Arguments:
        - child: the child node to compute ucb score.
        - min_max_stats: a tool used to min-max normalize the score.
        - parent_mean_q: the mean q value of the parent node.
        - is_reset: whether the parent's value prefix is reset (1), in which case the child's value prefix is \
            used directly as its reward.
        - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
        - parent_value_prefix: the value prefix of parent node.
        - pb_c_base: constant c2 in muzero.
        - pb_c_init: constant c1 in muzero.
        - discount_factor: the discount factor of reward.
        - players: the number of players, 1 or 2. With 2 players the child's value is negated.
    Outputs:
        - ucb_value: the ucb score of the child.
    Raises:
        - ValueError: if ``players`` is not 1 or 2 for a visited child.
    """
    pb_c = math.log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(total_children_visit_counts) / (child.visit_count + 1)
    prior_score = pb_c * child.prior

    if child.visit_count == 0:
        value_score = parent_mean_q
    else:
        true_reward = child.value_prefix if is_reset == 1 else child.value_prefix - parent_value_prefix
        if players == 1:
            child_value = child.value
        elif players == 2:
            # The child's value is from the opponent's perspective.
            child_value = -child.value
        else:
            # Previously value_score was left unbound here and an UnboundLocalError escaped.
            raise ValueError(f'players must be 1 or 2, got {players}')
        value_score = true_reward + discount_factor * child_value

    value_score = min_max_stats.normalize(value_score)
    if value_score < 0:
        value_score = 0
    if value_score > 1:
        value_score = 1
    return prior_score + value_score
|
|
|
|
|
def batch_traverse(
        roots: Any,
        pb_c_base: float,
        pb_c_init: float,
        discount_factor: float,
        min_max_stats_lst: List[MinMaxStats],
        results: SearchResults,
        virtual_to_play: List,
) -> Tuple[List[None], List[None], List[None], Union[list, int]]:
    """
    Overview:
        Traverse (MCTS selection) a batch of roots: from each root, repeatedly select a child until an unexpanded
        leaf is reached, recording the search path in ``results``.
    Arguments:
        - roots (:obj:`Any`): a batch of root nodes, exposing them as ``roots.roots``.
        - pb_c_base (:obj:`float`): constant c2 used in pUCT rule, typically 19652.
        - pb_c_init (:obj:`float`): constant c1 used in pUCT rule, typically 1.25.
        - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, \
            if env is board_games, we set discount_factor=1.
        - min_max_stats_lst: an object whose ``stats_lst[i]`` is the min-max stats of root ``i``.
        - results (:obj:`SearchResults`): the container that receives the search paths and leaf nodes.
        - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and training in board games, \
            `virtual` is to emphasize that actions are performed on an imaginary hidden state. ``-1``/``None`` (or a \
            list of them) means single-player search; 1/2 means two-player search (a scalar 1/2 is expanded to one \
            entry per root).
    Returns:
        - latent_state_index_in_search_path (:obj:`list`): the list of x/first index of hidden state vector of the \
            searched node, i.e. the search depth.
        - latent_state_index_in_batch (:obj:`list`): the list of y/second index of hidden state vector of the \
            searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``.
        - last_actions (:obj:`list`): the action performed by the previous node.
        - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and training in board games, \
            `virtual` is to emphasize that actions are performed on an imaginary hidden state.
    Raises:
        - ValueError: if the player id in ``virtual_to_play`` is not one of -1, None, 1, 2.
    """
    # NOTE(review): parent_q is not reset between roots, so a root with no visited children starts from the last
    # mean_q of the previous root's traversal — confirm this is intended.
    parent_q = 0.0
    num = results.num
    results.search_lens = [None for _ in range(num)]
    results.last_actions = [None for _ in range(num)]
    results.nodes = [None for _ in range(num)]
    results.latent_state_index_in_search_path = [None for _ in range(num)]
    results.latent_state_index_in_batch = [None for _ in range(num)]
    results.search_paths = {i: [] for i in range(num)}

    # Infer the number of players from the (first) player id. Previously `players` was left unbound for None,
    # numpy inputs and empty lists.
    is_scalar = virtual_to_play is None or np.ndim(virtual_to_play) == 0
    if is_scalar:
        first_to_play = virtual_to_play
    elif len(virtual_to_play) > 0:
        first_to_play = virtual_to_play[0]
    else:
        first_to_play = None
    if first_to_play in [1, 2]:
        players = 2
        if is_scalar:
            # The players are toggled per root below, so a scalar needs one entry per root.
            virtual_to_play = [virtual_to_play for _ in range(num)]
    elif first_to_play in [-1, None]:
        players = 1
    else:
        raise ValueError(f'Unsupported virtual_to_play: {virtual_to_play}')

    for i in range(num):
        node = roots.roots[i]
        is_root = 1
        search_len = 0
        search_path = results.search_paths[i]
        search_path.append(node)

        # MCTS stage 1: Selection. Starting from the root, descend until reaching a node that is not expanded.
        while node.expanded:
            mean_q = node.compute_mean_q(is_root, parent_q, discount_factor)
            is_root = 0
            parent_q = mean_q

            action = select_child(
                node, min_max_stats_lst.stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players
            )
            if players == 2:
                # The player to move alternates after every simulated action.
                virtual_to_play[i] = 2 if virtual_to_play[i] == 1 else 1
            node.best_action = action

            node = node.get_child(action)
            last_action = action
            search_path.append(node)
            search_len += 1

        # The parent of the leaf holds the latent state that the leaf's action is applied to.
        parent = search_path[len(search_path) - 2]
        results.latent_state_index_in_search_path[i] = parent.simulation_index
        results.latent_state_index_in_batch[i] = parent.batch_index
        results.last_actions[i] = last_action
        results.search_lens[i] = search_len
        results.nodes[i] = node

    return results.latent_state_index_in_search_path, results.latent_state_index_in_batch, results.last_actions, virtual_to_play
|
|
|
|
|
def backpropagate(
        search_path: List["Node"], min_max_stats: "MinMaxStats", to_play: int, value: float, discount_factor: float
) -> None:
    """
    Overview:
        Update the value sum and visit count of the nodes along the search path (leaf to root), and feed each
        node's q value into ``min_max_stats``.
    Arguments:
        - search_path: the nodes on the search path, root first.
        - min_max_stats: a tool used to min-max normalize the q value.
        - to_play: ``-1``/``None`` for single-player search, otherwise (1 or 2) the player of the leaf's value.
        - value: the value to propagate along the search path.
        - discount_factor: the discount factor of reward.
    """
    assert to_play is None or to_play in [-1, 1, 2], f'to_play is {to_play}!'
    single_player = to_play is None or to_play == -1

    bootstrap_value = value
    for i in range(len(search_path) - 1, -1, -1):
        node = search_path[i]
        if single_player or node.to_play == to_play:
            node.value_sum += bootstrap_value
        else:
            # Two players: the value is from the other player's perspective.
            node.value_sum -= bootstrap_value
        node.visit_count += 1

        if i >= 1:
            parent = search_path[i - 1]
            parent_value_prefix, is_reset = parent.value_prefix, parent.is_reset
        else:
            parent_value_prefix, is_reset = 0.0, 0
        # A reset parent value prefix means the node's own value prefix is its reward. The adjustment is applied
        # before updating min_max_stats in both modes, so the bounds cover the same q that compute_ucb_score
        # normalizes (the single-player branch used to update with the unadjusted reward).
        true_reward = node.value_prefix if is_reset == 1 else node.value_prefix - parent_value_prefix

        if single_player:
            min_max_stats.update(true_reward + discount_factor * node.value)
            bootstrap_value = true_reward + discount_factor * bootstrap_value
        else:
            min_max_stats.update(true_reward + discount_factor * -node.value)
            signed_reward = -true_reward if node.to_play == to_play else true_reward
            bootstrap_value = signed_reward + discount_factor * bootstrap_value
|
|
|
|
|
def batch_backpropagate(
        simulation_index: int,
        discount_factor: float,
        value_prefixs: List,
        values: List[float],
        policies: List[float],
        min_max_stats_lst: List[MinMaxStats],
        results: SearchResults,
        is_reset_list: List,
        to_play: list = None,
) -> None:
    """
    Overview:
        For every root in the batch, expand the leaf reached by the last traversal and backpropagate its value
        along the search path.
    Arguments:
        - simulation_index (:obj:`Class Int`): The index of latent state of the leaf node in the search path.
        - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, \
            if env is board_games, we set discount_factor=1.
        - value_prefixs (:obj:`Class List`): the value prefix of each leaf node.
        - values (:obj:`Class List`): the value of each leaf, propagated along its search path.
        - policies (:obj:`Class List`): the policy logits of each leaf node.
        - min_max_stats_lst: an object whose ``stats_lst[i]`` is the min-max stats of root ``i``.
        - results (:obj:`Class SearchResults`): the search results filled by the traversal.
        - is_reset_list (:obj:`Class List`): for each leaf, whether its value prefix is reset.
        - to_play (:obj:`Class List`): ``-1``/``None`` for single-player search, otherwise the player of each leaf \
            (a scalar is used for every leaf; lists and numpy arrays are indexed per leaf).
    """
    # Resolve the per-leaf players once. `to_play in [-1, None]` raised for numpy arrays with several elements.
    if to_play is None or (np.ndim(to_play) == 0 and to_play == -1):
        to_play_batch = [-1] * results.num
    elif np.ndim(to_play) == 0:
        to_play_batch = [to_play] * results.num
    else:
        to_play_batch = to_play

    for i in range(results.num):
        leaf_to_play = to_play_batch[i]
        # MCTS stage 2: Expansion of the leaf with the network outputs.
        results.nodes[i].expand(leaf_to_play, simulation_index, i, value_prefixs[i], policies[i])
        results.nodes[i].is_reset = is_reset_list[i]
        # MCTS stage 3: Backup along the search path.
        backpropagate(
            results.search_paths[i], min_max_stats_lst.stats_lst[i], leaf_to_play, values[i], discount_factor
        )
|
|