import copy
from typing import List, Dict, Any, Tuple, Union

import numpy as np
import torch
import torch.optim as optim
from ding.model import model_wrap
from ding.torch_utils import to_tensor
from ding.utils import POLICY_REGISTRY
from ditk import logging
from torch.distributions import Categorical, Independent, Normal
from torch.nn import L1Loss

from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree
from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree
from lzero.model import ImageTransforms
from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \
    DiscreteSupport, to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, \
    prepare_obs, configure_optimizers
from lzero.policy.muzero import MuZeroPolicy


@POLICY_REGISTRY.register('sampled_efficientzero')
class SampledEfficientZeroPolicy(MuZeroPolicy):
    """
    Overview:
        The policy class for Sampled EfficientZero, which combines EfficientZero with the sampled-action MCTS
        of Sampled MuZero, proposed in the paper "Learning and Planning in Complex Action Spaces"
        (https://arxiv.org/abs/2104.06303).
    """

    # The default_config for Sampled EfficientZero policy.
    config = dict(
        model=dict(
            # (str) The model type. For 1-dimensional vector obs, use 'mlp'; for 3-dimensional image obs, use 'conv'.
            model_type='conv',  # options={'mlp', 'conv'}
            # (bool) If True, the action space of the environment is continuous, otherwise discrete.
            continuous_action_space=False,
            # (tuple) The stacked observation shape.
            observation_shape=(4, 96, 96),
            # (bool) Whether to use the self-supervised learning (consistency) loss.
            self_supervised_learning_loss=True,
            # (int) The action space size: the number of discrete actions, or the action dimension for
            # continuous action spaces.
            action_space_size=6,
            # (bool) Whether to use a categorical (discrete-support) distribution for value and value_prefix.
            categorical_distribution=True,
            # (int) The image channel in the image observation.
            image_channel=1,
            # (int) The number of frames to stack together.
            frame_stack_num=1,
            # (int) The scale of the value/reward support, i.e. the support set is [-support_scale, support_scale].
            support_scale=300,
            # (int) The hidden size of the LSTM in the dynamics network that predicts the value_prefix.
            lstm_hidden_size=512,
            # (str) The type of sigma in the Gaussian policy head for continuous action spaces.
            sigma_type='conditioned',  # options={'conditioned', 'fixed'}
            # (float) The fixed sigma value used when sigma_type='fixed'.
            fixed_sigma_value=0.3,
            # (bool) Whether to use bias in the network layers.
            bias=True,
            # (str) The encoding type of discrete actions fed into the dynamics network.
            discrete_action_encoding_type='one_hot',
            # (bool) Whether to use residual connections in the dynamics network.
            res_connection_in_dynamics=True,
            # (str) The normalization type used in the networks.
            norm_type='BN',
        ),
        # ****** common ******
        # (bool) Whether to use multi-gpu (DDP) training.
        multi_gpu=False,
        # (bool) Whether the algorithm is a sampled-action algorithm. Must be True for Sampled EfficientZero.
        sampled_algo=True,
        # (bool) Whether the algorithm is a Gumbel-based algorithm.
        gumbel_algo=False,
        # (bool) Whether to use the C++ MCTS tree (ctree); if False, the Python tree (ptree) is used.
        mcts_ctree=True,
        # (bool) Whether to use CUDA for the network.
        cuda=True,
        # (int) The number of environments used for data collection.
        collector_env_num=8,
        # (int) The number of environments used for evaluation.
        evaluator_env_num=3,
        # (str) The type of environment.
        env_type='not_board_games',  # options={'not_board_games', 'board_games'}
        # (str) The type of action space.
        action_type='fixed_action_space',
        # (str) The battle mode, only used for board games.
        battle_mode='play_with_bot_mode',
        # (bool) Whether to monitor extra statistics in tensorboard.
        monitor_extra_statistics=True,
        # (int) The max length of one game segment stored in the replay buffer.
        game_segment_length=200,

        # ****** observation ******
        # (bool) Whether to transform image observations to strings to save memory.
        transform2string=False,
        # (bool) Whether to use gray-scale images.
        gray_scale=False,
        # (bool) Whether to use data augmentation on the observations.
        use_augmentation=False,
        # (list) The style of the augmentation.
        augmentation=['shift', 'intensity'],

        # ****** learn ******
        # (bool) Whether to ignore the done flag in the training data (useful for some infinite-horizon tasks).
        ignore_done=False,
        # (int) The number of update steps after each data collection. If None, it is derived from the number
        # of collected transitions and ``model_update_ratio``.
        update_per_collect=None,
        # (float) The ratio of update steps to collected transitions, used when ``update_per_collect`` is None.
        model_update_ratio=0.1,
        # (int) The batch size for training.
        batch_size=256,
        # (str) The optimizer type.
        optim_type='SGD',  # options={'SGD', 'Adam', 'AdamW'}
        # (float) The initial learning rate; decayed according to ``lr_piecewise_constant_decay`` by default.
        learning_rate=0.2,
        # (float) The initialization range of the last layer of the Gaussian policy head (continuous actions).
        init_w=3e-3,
        # (bool) Whether to normalize the probabilities of the sampled actions in the policy loss.
        normalize_prob_of_sampled_actions=False,
        # (str) The policy loss type.
        policy_loss_type='cross_entropy',  # options={'cross_entropy', 'KL'}
        # (int) The frequency (in training iterations) of hard target-network updates.
        target_update_freq=100,
        # (float) The weight decay of the optimizer.
        weight_decay=1e-4,
        # (float) The momentum of the SGD optimizer.
        momentum=0.9,
        # (float) The maximum gradient norm used for gradient clipping.
        grad_clip_value=10,
        # (int) The number of episodes collected in each collect phase.
        n_episode=8,
        # (int) The number of MCTS simulations per move.
        num_simulations=50,
        # (float) The discount factor of the environment.
        discount_factor=0.997,
        # (int) The number of TD steps used to compute the value target.
        td_steps=5,
        # (int) The number of unroll steps during training.
        num_unroll_steps=5,
        # (int) The horizon length of the value-prefix LSTM; its hidden state is reset every ``lstm_horizon_len`` steps.
        lstm_horizon_len=5,
        # (float) The weight of the value_prefix (reward) loss.
        reward_loss_weight=1,
        # (float) The weight of the value loss.
        value_loss_weight=0.25,
        # (float) The weight of the policy loss.
        policy_loss_weight=1,
        # (float) The weight of the policy entropy loss.
        policy_entropy_loss_weight=0,
        # (float) The weight of the self-supervised (consistency) loss.
        ssl_loss_weight=2,
        # (bool) Whether to use a cosine learning-rate schedule.
        cos_lr_scheduler=False,
        # (bool) Whether to use a piecewise-constant learning-rate decay (x1 -> x0.1 -> x0.01).
        lr_piecewise_constant_decay=True,
        # (int) The number of training steps after which the final learning rate is reached.
        threshold_training_steps_for_final_lr=int(5e4),
        # (int) The number of training steps after which the final MCTS collect temperature is reached
        # (only used when ``manual_temperature_decay`` is True).
        threshold_training_steps_for_final_temperature=int(1e5),
        # (bool) Whether to decay the MCTS collect temperature manually; if False, ``fixed_temperature_value`` is used.
        manual_temperature_decay=False,
        # (float) The fixed MCTS collect temperature used when ``manual_temperature_decay`` is False.
        fixed_temperature_value=0.25,
        # (bool) Whether to use the true chance label in the chance encoder (only relevant for stochastic variants).
        use_ture_chance_label_in_chance_encoder=False,

        # ****** priority ******
        # (bool) Whether to use prioritized experience replay.
        use_priority=True,
        # (float) The alpha exponent that controls how much prioritization is used.
        priority_prob_alpha=0.6,
        # (float) The beta exponent for importance-sampling correction, usually annealed towards 1 during training.
        priority_prob_beta=0.4,

        # ****** UCB / exploration ******
        # (float) The alpha of the Dirichlet noise added to the root prior during collection.
        root_dirichlet_alpha=0.3,
        # (float) The weight of the Dirichlet noise at the root.
        root_noise_weight=0.25,
        # (int) The number of episodes collected with a random policy before normal collection starts.
        random_collect_episode_num=0,
        # (dict) The epsilon-greedy exploration settings used in collect mode.
        eps=dict(
            # (bool) Whether to use epsilon-greedy exploration in collect mode.
            eps_greedy_exploration_in_collect=False,
            # (str) The decay type of epsilon.
            type='linear',  # options={'linear', 'exp'}
            # (float) The start value of epsilon.
            start=1.,
            # (float) The end value of epsilon.
            end=0.05,
            # (int) The number of environment steps over which epsilon decays from ``start`` to ``end``.
            decay=int(1e5),
        ),
    )

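    # A minimal usage sketch (hedged: the env-specific values below are illustrative, and
    # ``num_of_sampled_actions`` is not part of this default dict but is required by the policy code and is
    # normally supplied by the user config):
    #
    #     from easydict import EasyDict
    #     cfg = EasyDict(copy.deepcopy(SampledEfficientZeroPolicy.config))
    #     cfg.model.continuous_action_space = True   # e.g. a Pendulum-like env
    #     cfg.model.action_space_size = 1
    #     cfg.model.num_of_sampled_actions = 20      # K in Sampled MuZero
    #     cfg.model.model_type = 'mlp'
    #     cfg.model.observation_shape = 3
    #
    # The training entry then deep-merges such a user config over this default ``config`` dict.
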
    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm's default model setting.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names.
                - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry.
                - import_names (:obj:`List[str]`): The model class path list used in this algorithm.

        .. note::
            The user can define and use a customized network model, but it must obey the same interface definition \
            indicated by the import_names path. For Sampled EfficientZero, the default model is \
            ``lzero.model.sampled_efficientzero_model.SampledEfficientZeroModel``.
        """
        if self._cfg.model.model_type == "conv":
            return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model']
        elif self._cfg.model.model_type == "mlp":
            return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp']
        else:
            raise ValueError("model type {} is not supported".format(self._cfg.model.model_type))

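    # For reference, the tuple returned by ``default_model`` is consumed by the framework roughly as follows
    # (a hedged sketch; the registry accessor shown here is illustrative, not the exact DI-engine code):
    #
    #     model_name, import_names = policy.default_model()
    #     for path in import_names:
    #         importlib.import_module(path)   # importing the module registers the model class
    #     # the framework then looks up ``model_name`` in its model registry and instantiates it with cfg.model
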
    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils.
        """
        assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type
        if self._cfg.model.continuous_action_space:
            # Weight initialization for the last layer of the Gaussian policy head in the prediction network.
            init_w = self._cfg.init_w
            self._model.prediction_network.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w)
            self._model.prediction_network.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w)
            self._model.prediction_network.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w)
            try:
                self._model.prediction_network.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w)
            except Exception as exception:
                logging.warning(exception)

        if self._cfg.optim_type == 'SGD':
            self._optimizer = optim.SGD(
                self._model.parameters(),
                lr=self._cfg.learning_rate,
                momentum=self._cfg.momentum,
                weight_decay=self._cfg.weight_decay,
            )
        elif self._cfg.optim_type == 'Adam':
            self._optimizer = optim.Adam(
                self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
            )
        elif self._cfg.optim_type == 'AdamW':
            self._optimizer = configure_optimizers(
                model=self._model,
                weight_decay=self._cfg.weight_decay,
                learning_rate=self._cfg.learning_rate,
                device_type=self._cfg.device
            )

        if self._cfg.cos_lr_scheduler is True:
            from torch.optim.lr_scheduler import CosineAnnealingLR
            self.lr_scheduler = CosineAnnealingLR(self._optimizer, 1e6, eta_min=0, last_epoch=-1)

        if self._cfg.lr_piecewise_constant_decay:
            from torch.optim.lr_scheduler import LambdaLR
            max_step = self._cfg.threshold_training_steps_for_final_lr
            # lr: 1 -> 0.1 -> 0.01 of the initial learning rate.
            lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)
            self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda)

        # Use a deep copy of the model as the target model, updated by hard assignment every
        # ``target_update_freq`` training iterations.
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='assign',
            update_kwargs={'freq': self._cfg.target_update_freq}
        )
        self._learn_model = self._model

        if self._cfg.use_augmentation:
            self.image_transforms = ImageTransforms(
                self._cfg.augmentation,
                image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2])
            )
        self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1)
        self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1)
        self.inverse_scalar_transform_handle = InverseScalarTransform(
            self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
        )

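    # With the default learning_rate=0.2 and threshold_training_steps_for_final_lr=int(5e4), the
    # piecewise-constant schedule above yields lr=0.2 for steps < 25k, lr=0.02 for steps in [25k, 50k),
    # and lr=0.002 afterwards.
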
    def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
        """
        Overview:
            The forward function for learning policy in learn mode, which is the core of the learning process.
            The data is sampled from the replay buffer.
            The loss is calculated by the loss function and is backpropagated to update the model.
        Arguments:
            - data (:obj:`Tuple[torch.Tensor]`): The data sampled from the replay buffer, which is a tuple of tensors.
                The first element is the current_batch, the second element is the target_batch.
        Returns:
            - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
                current learning loss and learning statistics.
        """
        self._learn_model.train()
        self._target_model.train()

        current_batch, target_batch = data
        # ``child_sampled_actions_batch`` stores the K actions sampled at each MCTS root during collection.
        obs_batch_ori, action_batch, child_sampled_actions_batch, mask_batch, indices, weights, make_time = current_batch
        target_value_prefix, target_value, target_policy = target_batch

        obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg)

        # Apply augmentations if needed.
        if self._cfg.use_augmentation:
            obs_batch = self.image_transforms.transform(obs_batch)
            if self._cfg.model.self_supervised_learning_loss:
                obs_target_batch = self.image_transforms.transform(obs_target_batch)

        action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float()
        data_list = [
            mask_batch,
            target_value_prefix.astype('float32'),
            target_value.astype('float32'), target_policy, weights
        ]
        [mask_batch, target_value_prefix, target_value, target_policy,
         weights] = to_torch_float_tensor(data_list, self._cfg.device)

        # shape: (batch_size, num_unroll_steps + 1, num_of_sampled_actions, action_dim).
        child_sampled_actions_batch = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device)

        target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1)
        target_value = target_value.view(self._cfg.batch_size, -1)

        assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0)

        # Transform the scalar targets with h(.), i.e. the value rescaling used in MuZero
        # (https://arxiv.org/pdf/1805.11593.pdf).
        transformed_target_value_prefix = scalar_transform(target_value_prefix)
        transformed_target_value = scalar_transform(target_value)
        # Transform the scaled scalars to their categorical representations over the discrete support.
        target_value_prefix_categorical = phi_transform(self.reward_support, transformed_target_value_prefix)
        target_value_categorical = phi_transform(self.value_support, transformed_target_value)

        # ==============================================================
        # the core initial_inference in Sampled EfficientZero policy.
        # ==============================================================
        network_output = self._learn_model.initial_inference(obs_batch)
        latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output)

        # Transform the categorical representation of the scaled value to its original scalar value.
        original_value = self.inverse_scalar_transform_handle(value)

        # Note: the following lines are only for logging.
        predicted_value_prefixs = []
        if self._cfg.monitor_extra_statistics:
            latent_state_list = latent_state.detach().cpu().numpy()
            predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax(
                policy_logits, dim=1
            ).detach().cpu()

        # Calculate the new priorities for each transition.
        value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0])
        value_priority = value_priority.data.cpu().numpy() + 1e-6

        # ==============================================================
        # calculate the value and policy loss for the first step.
        # ==============================================================
        value_loss = cross_entropy_loss(value, target_value_categorical[:, 0])

        policy_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)

        if self._cfg.model.continuous_action_space:
            """continuous action space"""
            policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont(
                policy_loss, policy_logits, target_policy, mask_batch, child_sampled_actions_batch, unroll_step=0
            )
        else:
            """discrete action space"""
            policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions = self._calculate_policy_loss_disc(
                policy_loss, policy_logits, target_policy, mask_batch, child_sampled_actions_batch, unroll_step=0
            )

        value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)
        consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)

        # ==============================================================
        # the core recurrent_inference loop in Sampled EfficientZero policy.
        # ==============================================================
        for step_k in range(self._cfg.num_unroll_steps):
            # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state`` and
            # ``value_prefix`` given the current ``latent_state``, ``reward_hidden_state`` and ``action``,
            # then predict ``policy_logits`` and ``value`` with the prediction function.
            network_output = self._learn_model.recurrent_inference(
                latent_state, reward_hidden_state, action_batch[:, step_k]
            )
            latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(
                network_output
            )

            # Transform the categorical representation of the scaled value to its original scalar value.
            original_value = self.inverse_scalar_transform_handle(value)

            if self._cfg.model.self_supervised_learning_loss:
                # ==============================================================
                # calculate the consistency loss for the next ``num_unroll_steps`` unroll steps.
                # ==============================================================
                if self._cfg.ssl_loss_weight > 0:
                    # obtain the oracle latent states from the representation function.
                    beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
                    network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index])

                    latent_state = to_tensor(latent_state)
                    representation_state = to_tensor(network_output.latent_state)

                    # SimSiam-style consistency loss: gradients only flow through the dynamics branch.
                    dynamic_proj = self._learn_model.project(latent_state, with_grad=True)
                    observation_proj = self._learn_model.project(representation_state, with_grad=False)
                    temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k]

                    consistency_loss += temp_loss

            # NOTE: the target policy, target value_prefix and target value are computed in the game buffer.
            # ==============================================================
            # calculate the policy loss for the next ``num_unroll_steps`` unroll steps (note the ``+=``).
            # ==============================================================
            if self._cfg.model.continuous_action_space:
                """continuous action space"""
                policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont(
                    policy_loss,
                    policy_logits,
                    target_policy,
                    mask_batch,
                    child_sampled_actions_batch,
                    unroll_step=step_k + 1
                )
            else:
                """discrete action space"""
                policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions = self._calculate_policy_loss_disc(
                    policy_loss,
                    policy_logits,
                    target_policy,
                    mask_batch,
                    child_sampled_actions_batch,
                    unroll_step=step_k + 1
                )

            value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1])
            value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_k])

            # reset the hidden state of the value-prefix LSTM every ``lstm_horizon_len`` unroll steps.
            if (step_k + 1) % self._cfg.lstm_horizon_len == 0:
                reward_hidden_state = (
                    torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device),
                    torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device)
                )

            if self._cfg.monitor_extra_statistics:
                original_value_prefixs = self.inverse_scalar_transform_handle(value_prefix)
                original_value_prefixs_cpu = original_value_prefixs.detach().cpu()

                predicted_values = torch.cat(
                    (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu())
                )
                predicted_value_prefixs.append(original_value_prefixs_cpu)
                predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()))
                latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy()))

        # ==============================================================
        # the core learn model update step.
        # ==============================================================
        # Weight the losses with the importance-sampling weights and the unroll masks (a mask of 0 means the
        # unroll step lies beyond the end of the trajectory, so the padded data should not contribute).
        loss = (
            self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss +
            self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss +
            self._cfg.policy_entropy_loss_weight * policy_entropy_loss
        )
        weighted_total_loss = (weights * loss).mean()

        gradient_scale = 1 / self._cfg.num_unroll_steps
        weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)
        self._optimizer.zero_grad()
        weighted_total_loss.backward()
        if self._cfg.multi_gpu:
            self.sync_gradients(self._learn_model)
        total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(
            self._learn_model.parameters(), self._cfg.grad_clip_value
        )
        self._optimizer.step()
        if self._cfg.cos_lr_scheduler or self._cfg.lr_piecewise_constant_decay:
            self.lr_scheduler.step()

        # ==============================================================
        # the core target model update step.
        # ==============================================================
        self._target_model.update(self._learn_model.state_dict())

        if self._cfg.monitor_extra_statistics:
            predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1)
            predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1)

        return_data = {
            'cur_lr': self._optimizer.param_groups[0]['lr'],
            'collect_mcts_temperature': self._collect_mcts_temperature,
            'weighted_total_loss': weighted_total_loss.item(),
            'total_loss': loss.mean().item(),
            'policy_loss': policy_loss.mean().item(),
            'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1),
            'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1),
            'value_prefix_loss': value_prefix_loss.mean().item(),
            'value_loss': value_loss.mean().item(),
            'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps,

            # ==============================================================
            # priority related
            # ==============================================================
            'value_priority': value_priority.flatten().mean().item(),
            'value_priority_orig': value_priority,
            'target_value_prefix': target_value_prefix.detach().cpu().numpy().mean().item(),
            'target_value': target_value.detach().cpu().numpy().mean().item(),
            'transformed_target_value_prefix': transformed_target_value_prefix.detach().cpu().numpy().mean().item(),
            'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(),
            'predicted_value_prefixs': predicted_value_prefixs.detach().cpu().numpy().mean().item(),
            'predicted_values': predicted_values.detach().cpu().numpy().mean().item()
        }

        if self._cfg.model.continuous_action_space:
            return_data.update({
                # ==============================================================
                # sampled related core code
                # ==============================================================
                'policy_mu_max': mu[:, 0].max().item(),
                'policy_mu_min': mu[:, 0].min().item(),
                'policy_mu_mean': mu[:, 0].mean().item(),
                'policy_sigma_max': sigma.max().item(),
                'policy_sigma_min': sigma.min().item(),
                'policy_sigma_mean': sigma.mean().item(),
                # take the first dim of the action space
                'target_sampled_actions_max': target_sampled_actions[:, :, 0].max().item(),
                'target_sampled_actions_min': target_sampled_actions[:, :, 0].min().item(),
                'target_sampled_actions_mean': target_sampled_actions[:, :, 0].mean().item(),
                'total_grad_norm_before_clip': total_grad_norm_before_clip.item()
            })
        else:
            return_data.update({
                # ==============================================================
                # sampled related core code
                # ==============================================================
                'target_sampled_actions_max': target_sampled_actions[:, :].float().max().item(),
                'target_sampled_actions_min': target_sampled_actions[:, :].float().min().item(),
                'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(),
                'total_grad_norm_before_clip': total_grad_norm_before_clip.item()
            })

        return return_data

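    # For intuition, ``scalar_transform`` above applies the MuZero value rescaling
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x (eps is a small constant, 0.001 in the paper),
    # ``phi_transform`` projects h(x) onto the discrete support [-support_scale, support_scale], and
    # ``inverse_scalar_transform_handle`` maps network outputs back to scalars with h^(-1).
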
    def _calculate_policy_loss_cont(
            self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor,
            mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int
    ) -> Tuple[torch.Tensor]:
        """
        Overview:
            Calculate the policy loss for continuous action spaces.
        Arguments:
            - policy_loss (:obj:`torch.Tensor`): The accumulated policy loss tensor.
            - policy_logits (:obj:`torch.Tensor`): The policy logits tensor.
            - target_policy (:obj:`torch.Tensor`): The target policy tensor.
            - mask_batch (:obj:`torch.Tensor`): The mask tensor.
            - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor.
            - unroll_step (:obj:`int`): The unroll step.
        Returns:
            - policy_loss (:obj:`torch.Tensor`): The policy loss tensor.
            - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor.
            - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor.
            - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor.
            - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor.
            - mu (:obj:`torch.Tensor`): The mu of the Gaussian policy.
            - sigma (:obj:`torch.Tensor`): The sigma of the Gaussian policy.
        """
        (mu, sigma
         ) = policy_logits[:, :self._cfg.model.action_space_size], policy_logits[:, -self._cfg.model.action_space_size:]

        dist = Independent(Normal(mu, sigma), 1)

        # take the hypothetical step k = unroll_step
        target_normalized_visit_count = target_policy[:, unroll_step]

        # NOTE: ``target_policy_entropy`` is only used for logging.
        non_masked_indices = torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1)
        # Check if there are any unmasked rows.
        if len(non_masked_indices) > 0:
            target_normalized_visit_count_masked = torch.index_select(
                target_normalized_visit_count, 0, non_masked_indices
            )
            target_dist = Categorical(target_normalized_visit_count_masked)
            target_policy_entropy = target_dist.entropy().mean()
        else:
            # Set target_policy_entropy to 0 if all rows are masked.
            target_policy_entropy = 0

        # shape: (batch_size, num_of_sampled_actions, action_dim).
        target_sampled_actions = child_sampled_actions_batch[:, unroll_step]

        policy_entropy = dist.entropy().mean()
        policy_entropy_loss = -dist.entropy()

        # Project the sampling-based improved policy back onto the space of representable policies,
        # i.e. fit the current policy to the MCTS visit-count distribution over the sampled actions.
        target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6)
        log_prob_sampled_actions = []
        for k in range(self._cfg.model.num_of_sampled_actions):
            # The actions stored in the buffer are squashed by tanh, so the Gaussian log-prob has to be
            # corrected with the change-of-variables term log(1 - tanh(u)^2).
            y = 1 - target_sampled_actions[:, k, :].pow(2)

            # Clamp the actions to the valid domain of arctanh to avoid inf values.
            min_val = torch.tensor(-1 + 1e-6).to(target_sampled_actions.device)
            max_val = torch.tensor(1 - 1e-6).to(target_sampled_actions.device)
            target_sampled_actions_clamped = torch.clamp(target_sampled_actions[:, k, :], min_val, max_val)
            target_sampled_actions_before_tanh = torch.arctanh(target_sampled_actions_clamped)

            # log-prob of the k-th sampled action under the current Gaussian policy, with the tanh correction.
            log_prob = dist.log_prob(target_sampled_actions_before_tanh).unsqueeze(-1)
            log_prob = log_prob - torch.log(y + 1e-6).sum(-1, keepdim=True)
            log_prob = log_prob.squeeze(-1)

            log_prob_sampled_actions.append(log_prob)

        # shape: (batch_size, num_of_sampled_actions).
        log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1)

        if self._cfg.normalize_prob_of_sampled_actions:
            # Normalize the probabilities of the sampled actions (the normalizer is detached).
            prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum(
                -1
            ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach()
            # Use a numerically stable log.
            log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6)

        if self._cfg.policy_loss_type == 'KL':
            # KL divergence between the target distribution and the current policy over the sampled actions.
            policy_loss += (
                torch.exp(target_log_prob_sampled_actions.detach()) *
                (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions)
            ).sum(-1) * mask_batch[:, unroll_step]
        elif self._cfg.policy_loss_type == 'cross_entropy':
            # Cross entropy between the target distribution and the current policy over the sampled actions.
            policy_loss += -torch.sum(
                torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1
            ) * mask_batch[:, unroll_step]

        return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma

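    # The correction in the loop above follows the tanh change of variables: for a = tanh(u) with
    # u ~ Normal(mu, sigma), log p(a) = log p(u) - sum_i log(1 - a_i^2), which is exactly
    # ``log_prob - torch.log(y + 1e-6).sum(-1, keepdim=True)`` (the 1e-6 terms are for numerical stability).
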
    def _calculate_policy_loss_disc(
            self, policy_loss: torch.Tensor, policy_logits: torch.Tensor, target_policy: torch.Tensor,
            mask_batch: torch.Tensor, child_sampled_actions_batch: torch.Tensor, unroll_step: int
    ) -> Tuple[torch.Tensor]:
        """
        Overview:
            Calculate the policy loss for discrete action spaces.
        Arguments:
            - policy_loss (:obj:`torch.Tensor`): The accumulated policy loss tensor.
            - policy_logits (:obj:`torch.Tensor`): The policy logits tensor.
            - target_policy (:obj:`torch.Tensor`): The target policy tensor.
            - mask_batch (:obj:`torch.Tensor`): The mask tensor.
            - child_sampled_actions_batch (:obj:`torch.Tensor`): The child sampled actions tensor.
            - unroll_step (:obj:`int`): The unroll step.
        Returns:
            - policy_loss (:obj:`torch.Tensor`): The policy loss tensor.
            - policy_entropy (:obj:`torch.Tensor`): The policy entropy tensor.
            - policy_entropy_loss (:obj:`torch.Tensor`): The policy entropy loss tensor.
            - target_policy_entropy (:obj:`torch.Tensor`): The target policy entropy tensor.
            - target_sampled_actions (:obj:`torch.Tensor`): The target sampled actions tensor.
        """
        prob = torch.softmax(policy_logits, dim=-1)
        dist = Categorical(prob)

        # take the hypothetical step k = unroll_step
        target_normalized_visit_count = target_policy[:, unroll_step]

        # NOTE: ``target_policy_entropy`` is only used for logging.
        target_normalized_visit_count_masked = torch.index_select(
            target_normalized_visit_count, 0,
            torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1)
        )
        target_dist = Categorical(target_normalized_visit_count_masked)
        target_policy_entropy = target_dist.entropy().mean()

        # shape: (batch_size, num_of_sampled_actions, action_dim).
        target_sampled_actions = child_sampled_actions_batch[:, unroll_step]

        policy_entropy = dist.entropy().mean()
        policy_entropy_loss = -dist.entropy()

        # Project the sampling-based improved policy back onto the space of representable policies.
        target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6)

        log_prob_sampled_actions = []
        for k in range(self._cfg.model.num_of_sampled_actions):
            # Ensure the sampled actions have a trailing action_dim axis for ``gather``.
            if len(target_sampled_actions.shape) == 2:
                target_sampled_actions = target_sampled_actions.unsqueeze(-1)

            # log-prob of the k-th sampled action under the current policy, i.e. log pi_theta(a_k | s).
            log_prob = torch.log(prob.gather(-1, target_sampled_actions[:, k].long()).squeeze(-1) + 1e-6)
            log_prob_sampled_actions.append(log_prob)

        # shape: (batch_size, num_of_sampled_actions).
        log_prob_sampled_actions = torch.stack(log_prob_sampled_actions, dim=-1)

        if self._cfg.normalize_prob_of_sampled_actions:
            # Normalize the probabilities of the sampled actions (the normalizer is detached).
            prob_sampled_actions_norm = torch.exp(log_prob_sampled_actions) / torch.exp(log_prob_sampled_actions).sum(
                -1
            ).unsqueeze(-1).repeat(1, log_prob_sampled_actions.shape[-1]).detach()
            # Use a numerically stable log.
            log_prob_sampled_actions = torch.log(prob_sampled_actions_norm + 1e-6)

        if self._cfg.policy_loss_type == 'KL':
            # KL divergence between the target distribution and the current policy over the sampled actions.
            policy_loss += (
                torch.exp(target_log_prob_sampled_actions.detach()) *
                (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions)
            ).sum(-1) * mask_batch[:, unroll_step]
        elif self._cfg.policy_loss_type == 'cross_entropy':
            # Cross entropy between the target distribution and the current policy over the sampled actions.
            policy_loss += -torch.sum(
                torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1
            ) * mask_batch[:, unroll_step]

        return policy_loss, policy_entropy, policy_entropy_loss, target_policy_entropy, target_sampled_actions

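    # In the 'cross_entropy' branch, the per-sample objective is L = -sum_k pi_target(a_k) * log pi_theta(a_k),
    # where a_1, ..., a_K are the actions sampled at the MCTS root (``child_sampled_actions_batch``) and
    # pi_target is the normalized visit-count distribution over them; masked unroll steps contribute zero.
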
    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
        """
        self._collect_model = self._model
        if self._cfg.mcts_ctree:
            self._mcts_collect = MCTSCtree(self._cfg)
        else:
            self._mcts_collect = MCTSPtree(self._cfg)
        self._collect_mcts_temperature = 1

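    # ``mcts_ctree=True`` selects the C++/Cython tree implementation (MCTSCtree); the pure-Python
    # MCTSPtree is functionally equivalent but slower, and is mainly useful for debugging.
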
    def _forward_collect(
        self, data: torch.Tensor, action_mask: list = None, temperature: float = 1, to_play: int = -1,
        epsilon: float = 0.25, ready_env_id: np.ndarray = None,
    ):
        """
        Overview:
            The forward function for collecting data in collect mode. Use the model to execute MCTS search.
            Choose the action through sampling during collect mode.
        Arguments:
            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
            - action_mask (:obj:`list`): The action mask, i.e. the actions that cannot be selected.
            - temperature (:obj:`float`): The temperature of the policy.
            - to_play (:obj:`int`): The player to play.
            - ready_env_id (:obj:`list`): The ids of the envs that are ready to collect.
        Shape:
            - data (:obj:`torch.Tensor`):
                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
            - temperature: :math:`(1, )`.
            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
            - ready_env_id: None
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
        """
        self._collect_model.eval()
        self._collect_mcts_temperature = temperature
        active_collect_env_num = data.shape[0]
        with torch.no_grad():
            network_output = self._collect_model.initial_inference(data)
            latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack(
                network_output
            )

            pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
            latent_state_roots = latent_state_roots.detach().cpu().numpy()
            reward_hidden_state_roots = (
                reward_hidden_state_roots[0].detach().cpu().numpy(),
                reward_hidden_state_roots[1].detach().cpu().numpy()
            )
            policy_logits = policy_logits.detach().cpu().numpy().tolist()

            if self._cfg.model.continuous_action_space is True:
                # For continuous action spaces the 'legal_actions' field is not used, so a dummy list of -1
                # with length ``num_of_sampled_actions`` is passed for each env.
                legal_actions = [
                    [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_collect_env_num)
                ]
            else:
                legal_actions = [
                    [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)
                ]

            if self._cfg.mcts_ctree:
                # cpp mcts_tree
                roots = MCTSCtree.roots(
                    active_collect_env_num, legal_actions, self._cfg.model.action_space_size,
                    self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                )
            else:
                # python mcts_tree
                roots = MCTSPtree.roots(
                    active_collect_env_num, legal_actions, self._cfg.model.action_space_size,
                    self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                )

            # Add Dirichlet noise to the root prior during collection (eval uses ``prepare_no_noise``).
            noises = [
                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions)
                                    ).astype(np.float32).tolist() for j in range(active_collect_env_num)
            ]

            roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play)
            self._mcts_collect.search(
                roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play
            )

            # list of length ``env_num``; each element is the visit-count distribution over the root's sampled actions.
            roots_visit_count_distributions = roots.get_distributions()
            roots_values = roots.get_values()
            roots_sampled_actions = roots.get_sampled_actions()

            data_id = [i for i in range(active_collect_env_num)]
            output = {i: None for i in data_id}
            if ready_env_id is None:
                ready_env_id = np.arange(active_collect_env_num)

            for i, env_id in enumerate(ready_env_id):
                distributions, value = roots_visit_count_distributions[i], roots_values[i]
                if self._cfg.mcts_ctree:
                    # In ctree, ``roots.get_sampled_actions()`` returns plain lists.
                    root_sampled_actions = np.array([action for action in roots_sampled_actions[i]])
                else:
                    # In ptree, the same method returns Action objects, so convert them to their values.
                    root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]])

                # ``action`` is the index within the root's sampled action set; it is mapped to the actual action below.
                action, visit_count_distribution_entropy = select_action(
                    distributions, temperature=self._collect_mcts_temperature, deterministic=False
                )

                if self._cfg.mcts_ctree:
                    # In ctree, the sampled action is returned directly.
                    action = np.array(roots_sampled_actions[i][action])
                else:
                    # In ptree, the sampled action is an Action object.
                    action = roots_sampled_actions[i][action].value

                if not self._cfg.model.continuous_action_space:
                    if len(action.shape) == 0:
                        action = int(action)
                    elif len(action.shape) == 1:
                        action = int(action[0])

                output[env_id] = {
                    'action': action,
                    'visit_count_distributions': distributions,
                    'root_sampled_actions': root_sampled_actions,
                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
                    'searched_value': value,
                    'predicted_value': pred_values[i],
                    'predicted_policy_logits': policy_logits[i],
                }

        return output

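    # A collector typically consumes the per-env output above roughly like this (a hedged sketch;
    # the surrounding collector variable names are illustrative):
    #
    #     policy_output = policy.forward(stacked_obs, action_mask, temperature, to_play)
    #     actions = {env_id: policy_output[env_id]['action'] for env_id in ready_env_id}
    #     timesteps = env_manager.step(actions)
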
    def _init_eval(self) -> None:
        """
        Overview:
            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
        """
        self._eval_model = self._model
        if self._cfg.mcts_ctree:
            self._mcts_eval = MCTSCtree(self._cfg)
        else:
            self._mcts_eval = MCTSPtree(self._cfg)

    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.ndarray = None):
        """
        Overview:
            The forward function for evaluating the current policy in eval mode. Use the model to execute MCTS search.
            Choose the action with the highest visit count (argmax) rather than sampling during eval mode.
        Arguments:
            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
            - action_mask (:obj:`list`): The action mask, i.e. the actions that cannot be selected.
            - to_play (:obj:`int`): The player to play.
            - ready_env_id (:obj:`list`): The ids of the envs that are ready to be evaluated.
        Shape:
            - data (:obj:`torch.Tensor`):
                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
            - ready_env_id: None
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
        """
        self._eval_model.eval()
        active_eval_env_num = data.shape[0]
        with torch.no_grad():
            network_output = self._eval_model.initial_inference(data)
            latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack(
                network_output
            )

            if not self._eval_model.training:
                # if not in training, obtain the scalars of the value/value_prefix.
                pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
                latent_state_roots = latent_state_roots.detach().cpu().numpy()
                reward_hidden_state_roots = (
                    reward_hidden_state_roots[0].detach().cpu().numpy(),
                    reward_hidden_state_roots[1].detach().cpu().numpy()
                )
                policy_logits = policy_logits.detach().cpu().numpy().tolist()

            if self._cfg.model.continuous_action_space is True:
                # For continuous action spaces the 'legal_actions' field is not used, so a dummy list of -1
                # with length ``num_of_sampled_actions`` is passed for each env.
                legal_actions = [
                    [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_eval_env_num)
                ]
            else:
                legal_actions = [
                    [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)
                ]

            if self._cfg.mcts_ctree:
                # cpp mcts_tree
                roots = MCTSCtree.roots(
                    active_eval_env_num, legal_actions, self._cfg.model.action_space_size,
                    self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                )
            else:
                # python mcts_tree
                roots = MCTSPtree.roots(
                    active_eval_env_num, legal_actions, self._cfg.model.action_space_size,
                    self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
                )

            roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play)
            self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play)

            # list of length ``env_num``; each element is the visit-count distribution over the root's sampled actions.
            roots_visit_count_distributions = roots.get_distributions()
            roots_values = roots.get_values()
            roots_sampled_actions = roots.get_sampled_actions()

            data_id = [i for i in range(active_eval_env_num)]
            output = {i: None for i in data_id}

            if ready_env_id is None:
                ready_env_id = np.arange(active_eval_env_num)

            for i, env_id in enumerate(ready_env_id):
                distributions, value = roots_visit_count_distributions[i], roots_values[i]
                try:
                    # In ptree, ``roots.get_sampled_actions()`` returns Action objects.
                    root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]])
                except Exception:
                    # In ctree, it returns plain lists.
                    root_sampled_actions = np.array([action for action in roots_sampled_actions[i]])

                # ``deterministic=True`` means the argmax action is selected instead of sampling,
                # i.e. the sampled action with the highest visit count.
                action, visit_count_distribution_entropy = select_action(
                    distributions, temperature=1, deterministic=True
                )

                # ``action`` is the index within the root's sampled action set; map it to the actual action.
                try:
                    # ptree: Action object.
                    action = roots_sampled_actions[i][action].value
                except Exception:
                    # ctree: plain value.
                    action = np.array(roots_sampled_actions[i][action])

                if not self._cfg.model.continuous_action_space:
                    if len(action.shape) == 0:
                        action = int(action)
                    elif len(action.shape) == 1:
                        action = int(action[0])

                output[env_id] = {
                    'action': action,
                    'visit_count_distributions': distributions,
                    'root_sampled_actions': root_sampled_actions,
                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
                    'searched_value': value,
                    'predicted_value': pred_values[i],
                    'predicted_policy_logits': policy_logits[i],
                }

        return output

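    # Note: eval mode differs from collect mode in two ways: no Dirichlet noise is added to the root prior
    # (``prepare_no_noise``), and ``select_action`` is called with ``deterministic=True``, so the sampled
    # action with the highest visit count is chosen instead of sampling with a temperature.
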
    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Register the variables to be monitored in learn mode. The registered variables will be logged in
            tensorboard according to the return value of ``_forward_learn``.
        """
        if self._cfg.model.continuous_action_space:
            return [
                'collect_mcts_temperature',
                'cur_lr',
                'total_loss',
                'weighted_total_loss',
                'policy_loss',
                'value_prefix_loss',
                'value_loss',
                'consistency_loss',
                'value_priority',
                'target_value_prefix',
                'target_value',
                'predicted_value_prefixs',
                'predicted_values',
                'transformed_target_value_prefix',
                'transformed_target_value',

                # ==============================================================
                # sampled related core code
                # ==============================================================
                'policy_entropy',
                'target_policy_entropy',
                'policy_mu_max',
                'policy_mu_min',
                'policy_mu_mean',
                'policy_sigma_max',
                'policy_sigma_min',
                'policy_sigma_mean',
                # take the first dim of the action space
                'target_sampled_actions_max',
                'target_sampled_actions_min',
                'target_sampled_actions_mean',
                'total_grad_norm_before_clip',
            ]
        else:
            return [
                'collect_mcts_temperature',
                'cur_lr',
                'total_loss',
                'weighted_total_loss',
                'policy_loss',
                'value_prefix_loss',
                'value_loss',
                'consistency_loss',
                'value_priority',
                'target_value_prefix',
                'target_value',
                'predicted_value_prefixs',
                'predicted_values',
                'transformed_target_value_prefix',
                'transformed_target_value',

                # ==============================================================
                # sampled related core code
                # ==============================================================
                'policy_entropy',
                'target_policy_entropy',
                'target_sampled_actions_max',
                'target_sampled_actions_min',
                'target_sampled_actions_mean',
                'total_grad_norm_before_clip',
            ]

    def _state_dict_learn(self) -> Dict[str, Any]:
        """
        Overview:
            Return the state_dict of learn mode, usually including model and optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The dict of the current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The dict of the policy learn state saved previously.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _process_transition(self, obs, policy_output, timestep):
        # Kept only for compatibility with the DI-engine Policy interface; LightZero builds training
        # samples from game segments in the game buffer instead.
        pass

    def _get_train_sample(self, data):
        # Kept only for compatibility with the DI-engine Policy interface; LightZero builds training
        # samples from game segments in the game buffer instead.
        pass