|
import copy |
|
from collections import namedtuple |
|
from typing import List, Dict, Tuple |
|
|
|
import numpy as np |
|
import torch
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
from ding.policy.base_policy import Policy |
|
from ding.torch_utils import to_device |
|
from ding.utils import POLICY_REGISTRY |
|
from ding.utils.data import default_collate |
|
from easydict import EasyDict |
|
|
|
from lzero.policy import configure_optimizers |
|
from lzero.policy.utils import pad_and_get_lengths, compute_entropy |
|
|
|
|
|
@POLICY_REGISTRY.register('sampled_alphazero') |
|
class SampledAlphaZeroPolicy(Policy): |
|
""" |
|
Overview: |
|
        The policy class for Sampled AlphaZero. Compared with vanilla AlphaZero, the MCTS search only expands
        a sampled subset of actions at the root (``mcts.num_of_sampled_actions``), which makes the algorithm
        applicable to large discrete or continuous action spaces.
|
""" |
|
|
|
|
|
    config = dict(
        # (str) The type of the policy, as registered in POLICY_REGISTRY.
        type='sampled_alphazero',
        # (bool) Whether to use the sampled-action variant of the Python MCTS tree.
        sampled_algo=False,
        # (bool) Whether to normalize the prior probabilities over the sampled actions.
        normalize_prob_of_sampled_actions=False,
        # (str) The type of the policy loss, either 'cross_entropy' or 'KL'.
        policy_loss_type='cross_entropy',
        # (bool) Whether to compile the model with ``torch.compile`` (requires torch>=2.0).
        torch_compile=False,
        # (bool) Whether to enable TensorFloat-32 matrix multiplication on supported GPUs.
        tensor_float_32=False,
|
        model=dict(
            # (tuple) The shape of the input observation: (channels, height, width).
            observation_shape=(3, 6, 6),
            # (int) The number of residual blocks in the model.
            num_res_blocks=1,
            # (int) The number of channels in the convolutional layers.
            num_channels=32,
        ),
        # (bool) Whether to use the C++ (ctree) MCTS implementation instead of the Python one.
        mcts_ctree=True,
        # (bool) Whether to use CUDA for the network.
        cuda=False,
        # (int or None) The number of training updates after each data collection. If None, it is derived
        # from ``model_update_ratio`` and the amount of newly collected data.
        update_per_collect=None,
        # (float) The ratio of training updates to collected transitions, used when ``update_per_collect`` is None.
        model_update_ratio=0.1,
        # (int) The batch size for one training step.
        batch_size=256,
        # (str) The optimizer type, one of 'SGD', 'Adam' or 'AdamW'.
        optim_type='SGD',
        # (float) The learning rate.
        learning_rate=0.2,
        # (float) The L2 weight decay.
        weight_decay=1e-4,
        # (float) The momentum of SGD.
        momentum=0.9,
        # (float) The maximum gradient norm used in ``clip_grad_norm_``.
        grad_clip_value=10,
        # (float) The weight of the value loss in the total loss.
        value_weight=1.0,
        # (float) The weight of the entropy regularizer in the total loss. This key is read by
        # ``_init_learn``; 0.0 (an assumed default) disables the entropy bonus.
        entropy_weight=0.0,
        # (int) The number of environments used for data collection.
        collector_env_num=8,
        # (int) The number of environments used for evaluation.
        evaluator_env_num=3,
        # (bool) Whether to use a piecewise constant learning-rate decay (1 -> 0.1 -> 0.01).
        lr_piecewise_constant_decay=True,
        # (int) The number of training steps after which the final learning rate is reached.
        threshold_training_steps_for_final_lr=int(5e5),
        # (bool) Whether to decay the MCTS sampling temperature manually during training.
        manual_temperature_decay=False,
        # (int) The number of training steps after which the final temperature is reached.
        threshold_training_steps_for_final_temperature=int(1e5),
        # (float) The fixed temperature used when ``manual_temperature_decay`` is False.
        fixed_temperature_value=0.25,
|
        mcts=dict(
            # (int) The number of MCTS simulations per move during collection.
            num_simulations=50,
            # (int) The maximum number of moves in one episode of the simulation env.
            max_moves=512,
            # (float) The alpha parameter of the Dirichlet noise added to the root prior.
            root_dirichlet_alpha=0.3,
            # (float) The weight of the Dirichlet noise at the root node.
            root_noise_weight=0.25,
            # (int) The pb_c_base constant of the PUCT formula.
            pb_c_base=19652,
            # (float) The pb_c_init constant of the PUCT formula.
            pb_c_init=1.25,
            # (list or None) The legal actions; if None, they are obtained from the simulation env.
            legal_actions=None,
            # (int) The size of the action space.
            action_space_size=9,
            # (int) The number of actions sampled at the root in Sampled AlphaZero.
            num_of_sampled_actions=2,
            # (bool) Whether the action space is continuous.
            continuous_action_space=False,
        ),
        other=dict(replay_buffer=dict(
            # (int) The maximum number of transitions stored in the replay buffer.
            replay_buffer_size=int(1e6),
            # (bool) Whether to save episode data in the replay buffer.
            save_episode=False,
        )),
    )
|
|
|
def default_model(self) -> Tuple[str, List[str]]: |
|
""" |
|
Overview: |
|
            Return this algorithm's default model setting for demonstration.
|
Returns: |
|
- model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. |
|
- import_names (:obj:`List[str]`): The model class path list used in this algorithm. |
|
""" |
|
return 'AlphaZeroModel', ['lzero.model.alphazero_model'] |
|
|
|
def _init_learn(self) -> None: |
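        """
        Overview:
            Learn mode init method. Called by ``self.__init__``. Initialize the optimizer, the learning-rate \
            scheduler and the learn model.
        """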
|
assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type |
|
if self._cfg.optim_type == 'SGD': |
|
self._optimizer = optim.SGD( |
|
self._model.parameters(), |
|
lr=self._cfg.learning_rate, |
|
momentum=self._cfg.momentum, |
|
weight_decay=self._cfg.weight_decay, |
|
) |
|
elif self._cfg.optim_type == 'Adam': |
|
self._optimizer = optim.Adam( |
|
self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay |
|
) |
|
elif self._cfg.optim_type == 'AdamW': |
|
self._optimizer = configure_optimizers( |
|
model=self._model, |
|
weight_decay=self._cfg.weight_decay, |
|
learning_rate=self._cfg.learning_rate, |
|
device_type=self._cfg.device |
|
) |
|
|
|
if self._cfg.lr_piecewise_constant_decay: |
|
from torch.optim.lr_scheduler import LambdaLR |
|
max_step = self._cfg.threshold_training_steps_for_final_lr |
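            # Piecewise constant decay: keep the initial lr for the first ~1/3 of ``max_step`` training
            # steps, multiply it by 0.1 afterwards, and by 0.01 after ~2/3 of ``max_step``.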
|
|
|
|
|
lr_lambda = lambda step: 1 if step < max_step * 0.33 else (0.1 if step < max_step * 0.66 else 0.01) |
|
|
|
self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) |
|
|
|
|
|
self._value_weight = self._cfg.value_weight |
|
self._entropy_weight = self._cfg.entropy_weight |
|
|
|
self._learn_model = self._model |
|
|
|
|
|
if self._cfg.torch_compile: |
|
self._learn_model = torch.compile(self._learn_model) |
|
|
|
    def _forward_learn(self, inputs: List[Dict[str, torch.Tensor]]) -> Dict[str, float]:
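        """
        Overview:
            The forward function for learning in learn mode. It takes a batch of transitions, computes the \
            policy, value and entropy losses, performs one optimization step and returns a dict of \
            variables to be monitored.
        """
        # ``katago_game_state`` is only needed when simulating the environment; drop it so that it is not
        # collated into the training batch.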
|
for input_dict in inputs: |
|
|
|
if 'katago_game_state' in input_dict['obs']: |
|
del input_dict['obs']['katago_game_state'] |
|
|
|
|
|
if 'katago_game_state' in input_dict['next_obs']: |
|
del input_dict['next_obs']['katago_game_state'] |
|
|
|
|
|
|
|
inputs = pad_and_get_lengths(inputs, self._cfg.mcts.num_of_sampled_actions) |
|
inputs = default_collate(inputs) |
|
valid_action_length = inputs['action_length'] |
|
|
|
if self._cuda: |
|
inputs = to_device(inputs, self._device) |
|
self._learn_model.train() |
|
|
|
state_batch = inputs['obs']['observation'] |
|
mcts_visit_count_probs = inputs['probs'] |
|
reward = inputs['reward'] |
|
root_sampled_actions = inputs['root_sampled_actions'] |
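        # ``root_sampled_actions`` is expected to be 2-D with shape (batch_size, num_of_sampled_actions).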
|
|
|
if len(root_sampled_actions.shape) == 1: |
|
print(f"root_sampled_actions.shape: {root_sampled_actions.shape}") |
|
state_batch = state_batch.to(device=self._device, dtype=torch.float) |
|
mcts_visit_count_probs = mcts_visit_count_probs.to(device=self._device, dtype=torch.float) |
|
reward = reward.to(device=self._device, dtype=torch.float) |
|
|
|
policy_probs, values = self._learn_model.compute_policy_value(state_batch) |
|
policy_log_probs = torch.log(policy_probs) |
|
|
|
|
|
entropy = compute_entropy(policy_probs) |
|
entropy_loss = -entropy |
|
|
|
|
|
|
|
|
|
        policy_loss = self._calculate_policy_loss_disc(
            policy_probs, mcts_visit_count_probs, root_sampled_actions, valid_action_length
        )
|
|
|
|
|
|
|
|
|
value_loss = F.mse_loss(values.view(-1), reward) |
|
|
|
total_loss = self._value_weight * value_loss + policy_loss + self._entropy_weight * entropy_loss |
|
self._optimizer.zero_grad() |
|
total_loss.backward() |
|
|
|
total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( |
|
list(self._model.parameters()), |
|
max_norm=self._cfg.grad_clip_value, |
|
) |
|
self._optimizer.step() |
|
        if self._cfg.lr_piecewise_constant_decay:
|
self.lr_scheduler.step() |
|
|
|
|
|
|
|
|
|
return { |
|
'cur_lr': self._optimizer.param_groups[0]['lr'], |
|
'total_loss': total_loss.item(), |
|
'policy_loss': policy_loss.item(), |
|
'value_loss': value_loss.item(), |
|
'entropy_loss': entropy_loss.item(), |
|
'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), |
|
'collect_mcts_temperature': self.collect_mcts_temperature, |
|
} |
|
|
|
def _calculate_policy_loss_disc( |
|
self, policy_probs: torch.Tensor, target_policy: torch.Tensor, |
|
target_sampled_actions: torch.Tensor, valid_action_lengths: torch.Tensor |
|
) -> torch.Tensor: |
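        """
        Overview:
            Calculate the policy loss restricted to the root-sampled actions, for the discrete-action case.
        Arguments:
            - policy_probs (:obj:`torch.Tensor`): The action probabilities predicted by the model, \
                of shape (batch_size, action_space_size).
            - target_policy (:obj:`torch.Tensor`): The MCTS visit-count distribution over the full action space.
            - target_sampled_actions (:obj:`torch.Tensor`): The indices of the actions sampled at the root, \
                of shape (batch_size, num_of_sampled_actions), padded to a fixed length.
            - valid_action_lengths (:obj:`torch.Tensor`): The number of valid (unpadded) sampled actions per sample.
        Returns:
            - loss (:obj:`torch.Tensor`): A scalar policy loss, averaged over the valid sampled-action slots.
        """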
|
|
|
|
|
|
|
|
|
        # Restrict both the predicted policy and the MCTS visit-count target to the sampled actions.
        sampled_policy_probs = policy_probs.gather(1, target_sampled_actions)
        sampled_target_policy = target_policy.gather(1, target_sampled_actions)
|
|
|
|
|
        # Build a boolean mask that marks which of the (padded) sampled-action slots are valid for each sample.
        max_length = target_sampled_actions.size(1)
        mask = torch.arange(max_length).expand(len(valid_action_lengths), max_length) < valid_action_lengths.unsqueeze(1)
        mask = mask.to(device=self._device)
|
|
|
|
|
        # Zero out the padded slots, then renormalize over the valid sampled actions so that each row is
        # (approximately) a probability distribution; the epsilon guards against division by zero.
        sampled_policy_probs = sampled_policy_probs * mask.float()
        sampled_target_policy = sampled_target_policy * mask.float()
        sampled_policy_probs = sampled_policy_probs / (sampled_policy_probs.sum(dim=1, keepdim=True) + 1e-6)
        sampled_target_policy = sampled_target_policy / (sampled_target_policy.sum(dim=1, keepdim=True) + 1e-6)
|
|
|
|
|
|
|
        # After renormalization, explicitly zero the padded slots again so they do not contribute to the loss.
        sampled_policy_probs = torch.where(mask, sampled_policy_probs, torch.zeros_like(sampled_policy_probs))
        sampled_target_policy = torch.where(mask, sampled_target_policy, torch.zeros_like(sampled_target_policy))
|
|
|
if self._cfg.policy_loss_type == 'KL': |
|
|
|
|
|
|
|
|
|
            # Element-wise KL(target || predicted). ``log(0)`` on padded slots produces -inf/nan entries,
            # which are zeroed out before averaging over the valid slots only.
            loss = torch.nn.functional.kl_div(
                sampled_policy_probs.log(), sampled_target_policy, reduction='none'
            )
            loss = torch.nan_to_num(loss)
            loss = loss * mask.float()
            loss = loss.sum() / mask.sum()
|
|
|
elif self._cfg.policy_loss_type == 'cross_entropy': |
|
|
|
|
|
|
|
|
|
            # Cross entropy against the index of the most-visited sampled action. Note that
            # ``F.cross_entropy`` interprets its first argument as unnormalized scores (logits); here the
            # renormalized sampled probabilities are passed directly as those scores.
            loss = torch.nn.functional.cross_entropy(
                sampled_policy_probs, torch.argmax(sampled_target_policy, dim=1), reduction='none'
            )
            loss = torch.nan_to_num(loss)
            loss = loss * mask.float()
            loss = loss.sum() / mask.sum()
|
|
|
else: |
|
raise ValueError(f"Invalid policy_loss_type: {self._cfg.policy_loss_type}") |
|
|
|
return loss |
|
|
|
def _init_collect(self) -> None: |
|
""" |
|
Overview: |
|
Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. |
|
""" |
|
self._get_simulation_env() |
|
|
|
self._collect_model = self._model |
|
if self._cfg.mcts_ctree: |
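            # The C++ (ctree) MCTS implementation is built separately; make sure the compiled
            # ``mcts_alphazero`` extension can be found on ``sys.path`` before importing it.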
|
import sys |
|
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build') |
|
import mcts_alphazero |
|
self._collect_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, self._cfg.mcts.num_simulations, |
|
self._cfg.mcts.pb_c_base, |
|
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha, |
|
self._cfg.mcts.root_noise_weight, self.simulate_env) |
|
else: |
|
if self._cfg.sampled_algo: |
|
from lzero.mcts.ptree.ptree_az_sampled import MCTS |
|
else: |
|
from lzero.mcts.ptree.ptree_az import MCTS |
|
self._collect_mcts = MCTS(self._cfg.mcts, self.simulate_env) |
|
|
|
self.collect_mcts_temperature = 1 |
|
|
|
@torch.no_grad() |
|
def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]: |
|
|
|
""" |
|
Overview: |
|
            The forward function for collecting data in collect mode, which resets the simulation env to the \
            current real-env state and performs an MCTS search from it.
|
Arguments: |
|
- obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \ |
|
corresponding obs in this timestep. |
|
- temperature (:obj:`float`): The temperature for MCTS search. |
|
Returns: |
|
            - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is \
                the corresponding policy output in this timestep, including action, probs and so on.
|
""" |
|
self.collect_mcts_temperature = temperature |
|
ready_env_id = list(obs.keys()) |
|
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id} |
|
        try:
            katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
        except KeyError:
            # ``katago_game_state`` is only provided by the Go environment when the KataGo bot is enabled.
            katago_game_state = {env_id: None for env_id in ready_env_id}
|
|
|
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id} |
|
output = {} |
|
self._policy_model = self._collect_model |
|
for env_id in ready_env_id: |
|
|
|
|
|
            # Configuration used to reset the simulation env to the current real-env state before the search.
            state_config_for_env_reset = EasyDict(
                dict(
                    start_player_index=start_player_index[env_id],
                    init_state=init_state[env_id],
                    katago_policy_init=True,
                    katago_game_state=katago_game_state[env_id],
                )
            )
|
|
|
            # During collection the action is sampled from the MCTS visit-count distribution at
            # ``collect_mcts_temperature`` (the final ``True`` argument enables sampling).
            action, mcts_visit_count_probs = self._collect_mcts.get_next_action(
                state_config_for_env_reset,
                self._policy_value_func,
                self.collect_mcts_temperature,
                True,
            )
|
|
|
|
|
|
|
output[env_id] = { |
|
'action': action, |
|
'probs': mcts_visit_count_probs, |
|
'root_sampled_actions': self._collect_mcts.get_sampled_actions(), |
|
} |
|
|
|
return output |
|
|
|
def _init_eval(self) -> None: |
|
""" |
|
Overview: |
|
Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. |
|
""" |
|
self._get_simulation_env() |
|
if self._cfg.mcts_ctree: |
|
import sys |
|
sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build') |
|
import mcts_alphazero |
|
|
|
self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, |
|
min(800, self._cfg.mcts.num_simulations * 4), |
|
self._cfg.mcts.pb_c_base, |
|
self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha, |
|
self._cfg.mcts.root_noise_weight, self.simulate_env) |
|
else: |
|
if self._cfg.sampled_algo: |
|
from lzero.mcts.ptree.ptree_az_sampled import MCTS |
|
else: |
|
from lzero.mcts.ptree.ptree_az import MCTS |
|
            mcts_eval_config = copy.deepcopy(self._cfg.mcts)
            # Use more simulations during evaluation (4x the collection value, capped at 800) for stronger play.
            mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)
            self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
|
|
|
self._eval_model = self._model |
|
|
|
def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]: |
|
|
|
""" |
|
Overview: |
|
The forward function for evaluating the current policy in eval mode, similar to ``self._forward_collect``. |
|
Arguments: |
|
- obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \ |
|
corresponding obs in this timestep. |
|
Returns: |
|
            - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is \
                the corresponding policy output in this timestep, including action, probs and so on.
|
""" |
|
ready_env_id = list(obs.keys()) |
|
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id} |
|
        try:
            katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
        except KeyError:
            # ``katago_game_state`` is only provided by the Go environment when the KataGo bot is enabled.
            katago_game_state = {env_id: None for env_id in ready_env_id}
|
|
|
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id} |
|
output = {} |
|
self._policy_model = self._eval_model |
|
for env_id in ready_env_id: |
|
|
|
|
|
|
|
            # Configuration used to reset the simulation env; ``katago_policy_init`` is disabled during evaluation.
            state_config_for_env_reset = EasyDict(
                dict(
                    start_player_index=start_player_index[env_id],
                    init_state=init_state[env_id],
                    katago_policy_init=False,
                    katago_game_state=katago_game_state[env_id],
                )
            )
|
|
|
|
|
            # During evaluation, use temperature 1.0 and deterministic action selection
            # (the final ``False`` argument disables sampling).
            action, mcts_visit_count_probs = self._eval_mcts.get_next_action(
                state_config_for_env_reset, self._policy_value_func, 1.0, False
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output[env_id] = { |
|
'action': action, |
|
'probs': mcts_visit_count_probs, |
|
} |
|
return output |
|
|
|
def _get_simulation_env(self): |
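        """
        Overview:
            Instantiate the simulation environment (tictactoe, gomoku or go) used inside the MCTS search, \
            according to ``simulation_env_name`` and ``simulation_env_config_type`` in the config.
        """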
|
assert self._cfg.simulation_env_name in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_name |
|
assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league', |
|
'sampled_play_with_bot'], self._cfg.simulation_env_config_type |
|
if self._cfg.simulation_env_name == 'tictactoe': |
|
from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv |
|
if self._cfg.simulation_env_config_type == 'play_with_bot': |
|
from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \ |
|
tictactoe_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'self_play': |
|
from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \ |
|
tictactoe_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'league': |
|
from zoo.board_games.tictactoe.config.tictactoe_alphazero_league_config import \ |
|
tictactoe_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot': |
|
from zoo.board_games.tictactoe.config.tictactoe_sampled_alphazero_bot_mode_config import \ |
|
tictactoe_sampled_alphazero_config as tictactoe_alphazero_config |
|
|
|
self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env) |
|
|
|
elif self._cfg.simulation_env_name == 'gomoku': |
|
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv |
|
if self._cfg.simulation_env_config_type == 'play_with_bot': |
|
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'self_play': |
|
from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'league': |
|
from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import gomoku_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot': |
|
from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import \ |
|
gomoku_sampled_alphazero_config as gomoku_alphazero_config |
|
|
|
self.simulate_env = GomokuEnv(gomoku_alphazero_config.env) |
|
elif self._cfg.simulation_env_name == 'go': |
|
from zoo.board_games.go.envs.go_env import GoEnv |
|
if self._cfg.simulation_env_config_type == 'play_with_bot': |
|
from zoo.board_games.go.config.go_alphazero_bot_mode_config import go_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'self_play': |
|
from zoo.board_games.go.config.go_alphazero_sp_mode_config import go_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'league': |
|
from zoo.board_games.go.config.go_alphazero_league_config import go_alphazero_config |
|
elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot': |
|
from zoo.board_games.go.config.go_sampled_alphazero_bot_mode_config import \ |
|
go_sampled_alphazero_config as go_alphazero_config |
|
|
|
self.simulate_env = GoEnv(go_alphazero_config.env) |
|
|
|
@torch.no_grad() |
|
def _policy_value_func(self, environment: 'Environment') -> Tuple[Dict[int, np.ndarray], float]: |
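        """
        Overview:
            The policy-value function used by MCTS: given a simulation environment, return a dict mapping \
            each legal action to its prior probability, together with the value of the current state.
        """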
|
|
|
legal_actions = environment.legal_actions |
|
|
|
|
|
        # ``current_state()`` returns the raw state and its encoded (scaled) representation; only the
        # encoded representation is fed to the policy-value network.
        current_state, state_scale = environment.current_state()
|
|
|
|
|
state_scale_tensor = torch.from_numpy(state_scale).to( |
|
device=self._device, dtype=torch.float |
|
).unsqueeze(0) |
|
|
|
|
|
with torch.no_grad(): |
|
action_probabilities, state_value = self._policy_model.compute_policy_value(state_scale_tensor) |
|
|
|
|
|
legal_action_probabilities = dict( |
|
zip(legal_actions, action_probabilities.squeeze(0)[legal_actions].detach().cpu().numpy())) |
|
|
|
|
|
return legal_action_probabilities, state_value.item() |
|
|
|
def _monitor_vars_learn(self) -> List[str]: |
|
""" |
|
Overview: |
|
Register the variables to be monitored in learn mode. The registered variables will be logged in |
|
            tensorboard according to the return value of ``_forward_learn``.
|
""" |
|
return super()._monitor_vars_learn() + [ |
|
'cur_lr', 'total_loss', 'policy_loss', 'value_loss', 'entropy_loss', 'total_grad_norm_before_clip', |
|
'collect_mcts_temperature' |
|
] |
|
|
|
def _process_transition(self, obs: Dict, model_output: Dict[str, torch.Tensor], timestep: namedtuple) -> Dict: |
|
""" |
|
Overview: |
|
            Generate the dict type transition (one timestep) data from the policy output and the \
            environment timestep.
|
""" |
|
        if 'katago_game_state' in obs:
|
del obs['katago_game_state'] |
|
|
|
|
|
|
|
|
|
return { |
|
'obs': obs, |
|
'next_obs': timestep.obs, |
|
'action': model_output['action'], |
|
'root_sampled_actions': model_output['root_sampled_actions'], |
|
'probs': model_output['probs'], |
|
'reward': timestep.reward, |
|
'done': timestep.done, |
|
} |
|
|
|
def _get_train_sample(self, data): |
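        # Left unimplemented here; kept for compatibility with the base ``Policy`` interface, since this
        # policy does not need any extra post-processing of collected trajectories.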
|
|
|
pass |
|
|