|
from typing import List, Dict, Any, Tuple, Union |
|
import copy |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from torch.distributions import Normal, Independent |
|
|
|
from ding.torch_utils import Adam, to_device |
|
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ |
|
qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data |
|
from ding.model import model_wrap |
|
from ding.utils import POLICY_REGISTRY |
|
from ding.utils.data import default_collate, default_decollate |
|
from .sac import SACPolicy |
|
from .dqn import DQNPolicy |
|
from .common_utils import default_preprocess_learn |
|
|
|
|
|
@POLICY_REGISTRY.register('edac')
class EDACPolicy(SACPolicy):
    """
    Overview:
        Policy class of EDAC algorithm. https://arxiv.org/pdf/2110.01548.pdf
        Compared with SAC, the critic is an ensemble of ``model.ensemble_num`` Q networks whose minimum is used \
        for the targets, and an ensemble gradient-diversity penalty weighted by ``learn.eta`` is added to the \
        critic loss.

    Config:
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  ``type``             str      edac          | RL policy register name, refer  | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     False         | Whether to use cuda for network |
        3  | ``random_``        int      10000         | Number of randomly collected    | Default to 10000 for
           | ``collect_size``                          | training samples in replay      | SAC, 25000 for DDPG/
           |                                           | buffer when training starts.    | TD3.
        4  | ``model.actor_``   int      256           | Hidden size of the actor        |
           | ``head_hidden_``                          | (policy) head.                  |
           | ``size``                                  |                                 |
        5  | ``model.critic_``  int      256           | Hidden size of the critic       |
           | ``head_hidden_``                          | (soft Q) head.                  |
           | ``size``                                  |                                 |
        6  | ``model.ensemble`` int      10            | Number of Q networks in the     | The diversity penalty
           | ``_num``                                  | critic ensemble.                | needs at least 2.
        7  | ``learn.learning`` float    3e-4          | Learning rate for the soft Q    |
           | ``_rate_q``                               | ensemble.                       |
        8  | ``learn.learning`` float    3e-4          | Learning rate for policy        |
           | ``_rate_policy``                          | network.                        |
        9  | ``learn.learning`` float    3e-4          | Learning rate for value         | Not read by EDAC's
           | ``_rate_value``                           | network.                        | ``_forward_learn``.
        10 | ``learn.alpha``    float    1.0           | Entropy regularization          | alpha is initiali-
           |                                           | coefficient.                    | zation for auto
           |                                           |                                 | `alpha`, when
           |                                           |                                 | auto_alpha is True
        11 | ``learn.eta``      float    0.1           | Weight of the ensemble gradient | Disabled when <= 0.
           |                                           | diversity penalty in the critic |
           |                                           | loss.                           |
        12 | ``learn.``         bool     True          | Determine whether to use        | Temperature parameter
           | ``auto_alpha``                            | auto temperature parameter      | determines the
           |                                           | `alpha`.                        | relative importance
           |                                           |                                 | of the entropy term
           |                                           |                                 | against the reward.
        13 | ``learn.with_``    bool     False         | Whether to subtract             |
           | ``q_entropy``                             | alpha * log_prob from the       |
           |                                           | target Q value.                 |
        14 | ``learn.-``        bool     False         | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                           | done flag.                      | in halfcheetah env.
        15 | ``learn.-``        float    0.005         | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                          | target network.                 | factor in polyak aver
           |                                           |                                 | aging for target
           |                                           |                                 | networks.
        == ==================== ======== ============= ================================= =======================
    """
    config = dict(
        type='edac',
        cuda=False,
        on_policy=False,
        multi_agent=False,
        priority=False,
        priority_IS_weight=False,
        random_collect_size=10000,
        model=dict(
            # Number of Q networks in the critic ensemble.
            ensemble_num=10,
            actor_head_hidden_size=256,
            critic_head_hidden_size=256,
        ),
        learn=dict(
            multi_gpu=False,
            update_per_collect=1,
            batch_size=256,
            learning_rate_q=3e-4,
            learning_rate_policy=3e-4,
            learning_rate_value=3e-4,
            learning_rate_alpha=3e-4,
            target_theta=0.005,
            discount_factor=0.99,
            alpha=1,
            auto_alpha=True,
            log_space=True,
            ignore_done=False,
            init_w=3e-3,
            min_q_weight=1.0,
            # Whether the target Q value also subtracts alpha * next_log_prob.
            with_q_entropy=False,
            # Weight of the ensemble gradient-diversity penalty; <= 0 disables it.
            eta=0.1,
        ),
        collect=dict(
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                replay_buffer_size=1000000,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
            automatically call this method to get the default model setting and create model.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
        """
        return 'edac', ['ding.model.template.edac']

    def _init_learn(self) -> None:
        """
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Delegates to ``SACPolicy._init_learn`` (q, value and policy optimizers, algorithm config, main and \
            target models), then reads the EDAC-specific options ``learn.eta`` and ``learn.with_q_entropy``.
        """
        super()._init_learn()
        self._eta = self._cfg.learn.eta
        self._with_q_entropy = self._cfg.learn.with_q_entropy
        self._forward_learn_cnt = 0

    def _sample_squashed_action(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Sample a tanh-squashed action from the current policy with the reparameterization trick and compute \
            its log-probability corrected for the tanh change of variables.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations fed to the actor.
        Returns:
            - action (:obj:`torch.Tensor`): ``tanh`` of the sampled pre-squash action.
            - log_prob (:obj:`torch.Tensor`): Corrected log-probability with a trailing singleton dimension.
        """
        (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)
        # log|d tanh(u) / du| = log(1 - tanh(u)^2); the 1e-6 avoids log(0) when |action| -> 1.
        y = 1 - action.pow(2) + 1e-6
        log_prob = dist.log_prob(pred).unsqueeze(-1)
        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
        return action, log_prob

    def _gradient_diversity_loss(self, obs: torch.Tensor, acs: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            EDAC ensemble gradient-diversity penalty: for every sample, the sum over ensemble pairs i != j of the \
            inner product of the L2-normalized action-gradients of Q_i and Q_j, averaged over the batch and \
            divided by ``ensemble_num - 1``. Requires ``ensemble_num >= 2``.
        Arguments:
            - obs (:obj:`torch.Tensor`): Batch observations.
            - acs (:obj:`torch.Tensor`): Batch actions; assumed (B, action_dim) — TODO confirm.
        Returns:
            - loss (:obj:`torch.Tensor`): Scalar penalty, differentiable w.r.t. the critic parameters.
        """
        ensemble_num = self._cfg.model.ensemble_num
        # Tile the batch once per ensemble member along a new leading dim; the actions must require grad
        # so that the Q values can be differentiated with respect to them.
        tiled_obs = obs.unsqueeze(0).repeat_interleave(ensemble_num, dim=0)
        tiled_acs = acs.unsqueeze(0).repeat_interleave(ensemble_num, dim=0).requires_grad_(True)
        q_pred = self._learn_model.forward({'obs': tiled_obs, 'action': tiled_acs}, mode='compute_critic')['q_value']
        # create_graph=True keeps the penalty differentiable w.r.t. the critic parameters.
        grads = torch.autograd.grad(q_pred.sum(), tiled_acs, retain_graph=True, create_graph=True)[0]
        grads = grads / (torch.norm(grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
        # Move the batch dim first, then build per-sample (ensemble x ensemble) inner-product matrices.
        grads = grads.transpose(0, 1)
        similarity = grads @ grads.permute(0, 2, 1)
        # Mask out the diagonal (each member's similarity with itself).
        masks = torch.eye(ensemble_num, device=obs.device).unsqueeze(dim=0).repeat(similarity.size(0), 1, 1)
        similarity = (1 - masks) * similarity
        return torch.mean(torch.sum(similarity, dim=(1, 2))) / (ensemble_num - 1)

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            One EDAC training iteration: update the Q ensemble (TD loss plus the optional gradient-diversity \
            penalty), then the policy, then the temperature ``alpha`` when ``auto_alpha`` is enabled, and finally \
            update the target model from the learn model.
        Arguments:
            - data (:obj:`dict`): Batch that, after ``default_preprocess_learn``, provides ``obs``, ``next_obs``, \
                ``action``, ``reward``, ``done`` and ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Current learning rates, per-sample ``priority``, mean \
                ``td_error``, current ``alpha``, mean ``target_q_value`` and every entry of the loss dict.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        # Give 1-D (scalar-per-sample) actions an explicit action dimension.
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']
        acs = data['action']
        ensemble_num = self._cfg.model.ensemble_num

        # ==================
        # critic (Q ensemble)
        # ==================
        # NOTE(review): q_value is presumably (ensemble_num, B) — the ensemble is reduced over dim 0 below.
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
        with torch.no_grad():
            next_action, next_log_prob = self._sample_squashed_action(next_obs)
            next_data = {'obs': next_obs, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # Pessimistic target: minimum over the ensemble members.
            target_q_value, _ = torch.min(target_q_value, dim=0)
            if self._with_q_entropy:
                target_q_value -= self._alpha * next_log_prob.squeeze(-1)
            target_q_value = self._gamma * (1 - done) * target_q_value + reward

        # Per-sample TD error: squared error summed over the ensemble members. With unit weights the critic loss
        # equals the sum over members of the batch-mean squared error, while priorities and IS weights now apply
        # per sample instead of to a single collapsed scalar.
        td_error_per_sample = nn.MSELoss(reduction='none')(q_value, target_q_value).sum(dim=0)
        weight = data['weight']
        if weight is None:
            weight = torch.ones_like(td_error_per_sample)
        loss_dict['critic_loss'] = (td_error_per_sample * weight).mean()

        # The diversity penalty divides by (ensemble_num - 1), so it is only defined for two or more members.
        if self._eta > 0 and ensemble_num > 1:
            loss_dict['critic_loss'] += self._gradient_diversity_loss(obs, acs) * self._eta

        self._optimizer_q.zero_grad()
        loss_dict['critic_loss'].backward()
        self._optimizer_q.step()

        # ==================
        # policy
        # ==================
        action, log_prob = self._sample_squashed_action(obs)
        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        new_q_value, _ = torch.min(new_q_value, dim=0)
        policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
        loss_dict['policy_loss'] = policy_loss

        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # ==================
        # temperature alpha
        # ==================
        if self._auto_alpha:
            # log_prob is detached in both losses so that only the temperature receives gradients.
            log_prob = log_prob + self._target_entropy
            if self._log_space:
                loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()
                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                self._alpha = self._log_alpha.detach().exp()
            else:
                loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()
                self._alpha_optim.zero_grad()
                loss_dict['alpha_loss'].backward()
                self._alpha_optim.step()
                # Clamp in place so ``self._alpha`` stays the tensor registered with ``_alpha_optim``;
                # ``max(0, self._alpha)`` would rebind it to the int 0 once alpha became non-positive.
                with torch.no_grad():
                    self._alpha.clamp_(min=0)

        loss_dict['total_loss'] = sum(loss_dict.values())

        # ==================
        # after update
        # ==================
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': td_error_per_sample.detach().abs().tolist(),
            'td_error': td_error_per_sample.detach().mean().item(),
            'alpha': self._alpha.item(),
            'target_q_value': target_q_value.detach().mean().item(),
            **loss_dict
        }
|
|