|
from typing import List, Dict, Any, Tuple, Union |
|
import copy |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.distributions import Normal, Independent |
|
|
|
from ding.torch_utils import Adam, to_device |
|
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ |
|
qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data |
|
from ding.model import model_wrap |
|
from ding.utils import POLICY_REGISTRY |
|
from ding.utils.data import default_collate, default_decollate |
|
from .sac import SACPolicy |
|
from .qrdqn import QRDQNPolicy |
|
from .common_utils import default_preprocess_learn |
|
|
|
|
|
@POLICY_REGISTRY.register('cql') |
|
class CQLPolicy(SACPolicy): |
|
""" |
|
Overview: |
|
Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779. |
|
|
|
Config: |
|
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  ``type``             str      cql           | RL policy register name, refer  | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     True          | Whether to use cuda for network |
        3  | ``random_``        int      10000         | Number of randomly collected    | Default to 10000 for
           | ``collect_size``                          | training samples in replay      | SAC, 25000 for DDPG/
                                                       | buffer when training starts.    | TD3.
        4  | ``model.policy_``  int      256           | Linear layer size for policy    |
           | ``embedding_size``                        | network.                        |
        5  | ``model.soft_q_``  int      256           | Linear layer size for soft q    |
           | ``embedding_size``                        | network.                        |
        6  | ``model.value_``   int      256           | Linear layer size for value     | Default to None when
           | ``embedding_size``                        | network.                        | model.value_network
                                                       |                                 | is False.
        7  | ``learn.learning`` float    3e-4          | Learning rate for soft q        | Default to 1e-3, when
           | ``_rate_q``                               | network.                        | model.value_network
                                                       |                                 | is True.
        8  | ``learn.learning`` float    3e-4          | Learning rate for policy        | Default to 1e-3, when
           | ``_rate_policy``                          | network.                        | model.value_network
                                                       |                                 | is True.
        9  | ``learn.learning`` float    3e-4          | Learning rate for value         | Default to None when
           | ``_rate_value``                           | network.                        | model.value_network
                                                       |                                 | is False.
        10 | ``learn.alpha``    float    0.2           | Entropy regularization          | alpha is initiali-
                                                       | coefficient.                    | zation for auto
                                                       |                                 | `alpha`, when
                                                       |                                 | auto_alpha is True
        11 | ``learn.repara_``  bool     True          | Determine whether to use        |
           | ``meterization``                          | reparameterization trick.       |
        12 | ``learn.``         bool     False         | Determine whether to use        | Temperature parameter
           | ``auto_alpha``                            | auto temperature parameter      | determines the
                                                       | `alpha`.                        | relative importance
                                                       |                                 | of the entropy term
                                                       |                                 | against the reward.
        13 | ``learn.``         bool     False         | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                           | done flag.                      | in halfcheetah env.
        14 | ``learn.``         float    0.005         | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                          | target network.                 | factor in polyak aver-
                                                       |                                 | aging for target
                                                       |                                 | networks.
        == ==================== ======== ============= ================================= =======================
|
""" |
|
|
|
config = dict( |
|
|
|
type='cql', |
|
|
|
cuda=False, |
|
|
|
|
|
on_policy=False, |
|
|
|
priority=False, |
|
|
|
priority_IS_weight=False, |
|
|
|
random_collect_size=10000, |
|
model=dict( |
|
|
|
|
|
|
|
twin_critic=True, |
|
|
|
action_space='reparameterization', |
|
|
|
actor_head_hidden_size=256, |
|
|
|
critic_head_hidden_size=256, |
|
), |
|
|
|
learn=dict( |
|
|
|
|
|
update_per_collect=1, |
|
|
|
batch_size=256, |
|
|
|
learning_rate_q=3e-4, |
|
|
|
learning_rate_policy=3e-4, |
|
|
|
learning_rate_alpha=3e-4, |
|
|
|
|
|
target_theta=0.005, |
|
|
|
discount_factor=0.99, |
|
|
|
|
|
|
|
|
|
alpha=0.2, |
|
|
|
|
|
|
|
|
|
|
|
auto_alpha=True, |
|
|
|
log_space=True, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ignore_done=False, |
|
|
|
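            # (float) Range of the uniform initialization applied to the last layer of the actor/critic heads.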
init_w=3e-3, |
|
|
|
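            # (int) Number of actions sampled per state for the CQL logsumexp (conservative) penalty term.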
num_actions=10, |
|
|
|
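            # (bool) Whether to tune the conservative penalty weight automatically with a Lagrange multiplier.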
with_lagrange=False, |
|
|
|
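            # (float) Target action gap for the Lagrange version; the Lagrange branch is only enabled when > 0.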
lagrange_thresh=-1, |
|
|
|
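            # (float) Weight of the conservative (min-Q) regularization term.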
min_q_weight=1.0, |
|
|
|
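            # (bool) Whether to subtract the entropy term (alpha * log_prob) when building the Bellman target.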
with_q_entropy=False, |
|
), |
|
eval=dict(), |
|
) |
|
|
|
def _init_learn(self) -> None: |
|
""" |
|
Overview: |
|
            Initialize the learn mode of policy, including related attributes and modules. For CQL, it mainly \
|
contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \ |
|
            with_q_entropy, main and target model. In particular, the ``auto_alpha`` mechanism for balancing max entropy \
|
target is also initialized here. |
|
This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. |
|
|
|
.. note:: |
|
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ |
|
and ``_load_state_dict_learn`` methods. |
|
|
|
.. note:: |
|
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. |
|
|
|
.. note:: |
|
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
|
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. |
|
""" |
|
self._priority = self._cfg.priority |
|
self._priority_IS_weight = self._cfg.priority_IS_weight |
|
self._twin_critic = self._cfg.model.twin_critic |
|
self._num_actions = self._cfg.learn.num_actions |
|
|
|
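        # Only min_q_version 3 (the importance-sampled logsumexp variant, as in the reference CQL implementation)
        # is implemented below.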
self._min_q_version = 3 |
|
self._min_q_weight = self._cfg.learn.min_q_weight |
|
        self._lagrange_thresh = self._cfg.learn.lagrange_thresh
        self._with_lagrange = self._cfg.learn.with_lagrange and (self._lagrange_thresh > 0)
|
if self._with_lagrange: |
|
self.target_action_gap = self._lagrange_thresh |
|
self.log_alpha_prime = torch.tensor(0.).to(self._device).requires_grad_() |
|
self.alpha_prime_optimizer = Adam( |
|
[self.log_alpha_prime], |
|
lr=self._cfg.learn.learning_rate_q, |
|
) |
|
|
|
self._with_q_entropy = self._cfg.learn.with_q_entropy |
|
|
|
|
|
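        # Initialize the last layers of the actor and critic heads with small uniform weights so that the initial
        # action means and Q-value estimates stay close to zero.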
init_w = self._cfg.learn.init_w |
|
self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w) |
|
self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w) |
|
self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w) |
|
self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w) |
|
if self._twin_critic: |
|
self._model.critic_head[0][-1].last.weight.data.uniform_(-init_w, init_w) |
|
self._model.critic_head[0][-1].last.bias.data.uniform_(-init_w, init_w) |
|
self._model.critic_head[1][-1].last.weight.data.uniform_(-init_w, init_w) |
|
self._model.critic_head[1][-1].last.bias.data.uniform_(-init_w, init_w) |
|
else: |
|
            self._model.critic_head[-1].last.weight.data.uniform_(-init_w, init_w)
            self._model.critic_head[-1].last.bias.data.uniform_(-init_w, init_w)
|
|
|
|
|
self._optimizer_q = Adam( |
|
self._model.critic.parameters(), |
|
lr=self._cfg.learn.learning_rate_q, |
|
) |
|
self._optimizer_policy = Adam( |
|
self._model.actor.parameters(), |
|
lr=self._cfg.learn.learning_rate_policy, |
|
) |
|
|
|
|
|
self._gamma = self._cfg.learn.discount_factor |
|
|
|
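        # Automatic entropy temperature (alpha) tuning, optionally parameterized in log space for stability.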
if self._cfg.learn.auto_alpha: |
|
if self._cfg.learn.target_entropy is None: |
|
                assert 'action_shape' in self._cfg.model, "CQL needs a network model with an action_shape field"
|
self._target_entropy = -np.prod(self._cfg.model.action_shape) |
|
else: |
|
self._target_entropy = self._cfg.learn.target_entropy |
|
if self._cfg.learn.log_space: |
|
self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha])) |
|
self._log_alpha = self._log_alpha.to(self._device).requires_grad_() |
|
self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha) |
|
assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad |
|
self._alpha = self._log_alpha.detach().exp() |
|
self._auto_alpha = True |
|
self._log_space = True |
|
else: |
|
self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_() |
|
self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha) |
|
self._auto_alpha = True |
|
self._log_space = False |
|
else: |
|
self._alpha = torch.tensor( |
|
[self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32 |
|
) |
|
self._auto_alpha = False |
|
|
|
|
|
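        # The target critic is updated as an exponential moving average (momentum/polyak) of the main critic,
        # with interpolation factor target_theta.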
self._target_model = copy.deepcopy(self._model) |
|
self._target_model = model_wrap( |
|
self._target_model, |
|
wrapper_name='target', |
|
update_type='momentum', |
|
update_kwargs={'theta': self._cfg.learn.target_theta} |
|
) |
|
self._learn_model = model_wrap(self._model, wrapper_name='base') |
|
self._learn_model.reset() |
|
self._target_model.reset() |
|
|
|
self._forward_learn_cnt = 0 |
|
|
|
def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: |
|
""" |
|
Overview: |
|
Policy forward function of learn mode (training policy and updating parameters). Forward means \ |
|
that the policy inputs some training batch data from the offline dataset and then returns the output \ |
|
result, including various training information such as loss, action, priority. |
|
Arguments: |
|
            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
|
training samples. For each element in list, the key of the dict is the name of data items and the \ |
|
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
|
dimension by some utility functions such as ``default_preprocess_learn``. \ |
|
For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ |
|
``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. |
|
Returns: |
|
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
|
recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ |
|
detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. |
|
|
|
.. note:: |
|
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For data types that are not supported, the main reason is that the corresponding model does not \
            support them. You can implement your own model rather than using the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.
|
""" |
|
loss_dict = {} |
|
data = default_preprocess_learn( |
|
data, |
|
use_priority=self._priority, |
|
use_priority_IS_weight=self._cfg.priority_IS_weight, |
|
ignore_done=self._cfg.learn.ignore_done, |
|
use_nstep=False |
|
) |
|
if len(data.get('action').shape) == 1: |
|
data['action'] = data['action'].reshape(-1, 1) |
|
|
|
if self._cuda: |
|
data = to_device(data, self._device) |
|
|
|
self._learn_model.train() |
|
self._target_model.train() |
|
obs = data['obs'] |
|
next_obs = data['next_obs'] |
|
reward = data['reward'] |
|
done = data['done'] |
|
|
|
|
|
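        # Q-values of the dataset (obs, action) pairs under the current critic.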
q_value = self._learn_model.forward(data, mode='compute_critic')['q_value'] |
|
|
|
|
|
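        # Bellman target: sample the next action from the current policy and evaluate it with the target critic,
        # optionally subtracting the entropy term (SAC-style).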
with torch.no_grad(): |
|
(mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit'] |
|
|
|
dist = Independent(Normal(mu, sigma), 1) |
|
pred = dist.rsample() |
|
next_action = torch.tanh(pred) |
|
y = 1 - next_action.pow(2) + 1e-6 |
|
next_log_prob = dist.log_prob(pred).unsqueeze(-1) |
|
next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True) |
|
|
|
next_data = {'obs': next_obs, 'action': next_action} |
|
target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value'] |
|
|
|
if self._twin_critic: |
|
|
|
if self._with_q_entropy: |
|
target_q_value = torch.min(target_q_value[0], |
|
target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1) |
|
else: |
|
target_q_value = torch.min(target_q_value[0], target_q_value[1]) |
|
else: |
|
if self._with_q_entropy: |
|
target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1) |
|
|
|
|
|
if self._twin_critic: |
|
q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight']) |
|
loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma) |
|
q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight']) |
|
loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma) |
|
td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2 |
|
else: |
|
q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight']) |
|
loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma) |
|
|
|
|
|
|
|
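        # CQL conservative penalty: sample multiple actions per state from the current policy (for obs and next_obs)
        # and uniformly at random, then penalize large Q-values on these actions via a logsumexp term while pushing
        # up the Q-values of the dataset actions.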
curr_actions_tensor, curr_log_pis = self._get_policy_actions(data, self._num_actions) |
|
new_curr_actions_tensor, new_log_pis = self._get_policy_actions({'obs': next_obs}, self._num_actions) |
|
|
|
random_actions_tensor = torch.FloatTensor(curr_actions_tensor.shape).uniform_(-1, |
|
1).to(curr_actions_tensor.device) |
|
|
|
obs_repeat = obs.unsqueeze(1).repeat(1, self._num_actions, |
|
1).view(obs.shape[0] * self._num_actions, obs.shape[1]) |
|
act_repeat = data['action'].unsqueeze(1).repeat(1, self._num_actions, 1).view( |
|
data['action'].shape[0] * self._num_actions, data['action'].shape[1] |
|
) |
|
|
|
q_rand = self._get_q_value({'obs': obs_repeat, 'action': random_actions_tensor}) |
|
|
|
q_curr_actions = self._get_q_value({'obs': obs_repeat, 'action': curr_actions_tensor}) |
|
|
|
q_next_actions = self._get_q_value({'obs': obs_repeat, 'action': new_curr_actions_tensor}) |
|
|
|
|
|
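        # Stack Q-values of random, dataset, next-policy and current-policy actions along the sampled-action
        # dimension; for min_q_version 3 the stack is rebuilt with importance-sampling (log-density) corrections.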
cat_q1 = torch.cat([q_rand[0], q_value[0].reshape(-1, 1, 1), q_next_actions[0], q_curr_actions[0]], 1) |
|
cat_q2 = torch.cat([q_rand[1], q_value[1].reshape(-1, 1, 1), q_next_actions[1], q_curr_actions[1]], 1) |
|
std_q1 = torch.std(cat_q1, dim=1) |
|
std_q2 = torch.std(cat_q2, dim=1) |
|
if self._min_q_version == 3: |
|
|
|
random_density = np.log(0.5 ** curr_actions_tensor.shape[-1]) |
|
cat_q1 = torch.cat( |
|
[ |
|
q_rand[0] - random_density, q_next_actions[0] - new_log_pis.detach(), |
|
q_curr_actions[0] - curr_log_pis.detach() |
|
], 1 |
|
) |
|
cat_q2 = torch.cat( |
|
[ |
|
q_rand[1] - random_density, q_next_actions[1] - new_log_pis.detach(), |
|
q_curr_actions[1] - curr_log_pis.detach() |
|
], 1 |
|
) |
|
|
|
min_qf1_loss = torch.logsumexp(cat_q1, dim=1).mean() * self._min_q_weight |
|
min_qf2_loss = torch.logsumexp(cat_q2, dim=1).mean() * self._min_q_weight |
|
"""Subtract the log likelihood of data""" |
|
min_qf1_loss = min_qf1_loss - q_value[0].mean() * self._min_q_weight |
|
min_qf2_loss = min_qf2_loss - q_value[1].mean() * self._min_q_weight |
|
|
|
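        # Optional Lagrange version: adapt the penalty weight alpha_prime so that the expected action gap stays
        # close to lagrange_thresh (target_action_gap).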
if self._with_lagrange: |
|
alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0) |
|
min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap) |
|
min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap) |
|
|
|
self.alpha_prime_optimizer.zero_grad() |
|
alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5 |
|
alpha_prime_loss.backward(retain_graph=True) |
|
self.alpha_prime_optimizer.step() |
|
|
|
loss_dict['critic_loss'] += min_qf1_loss |
|
if self._twin_critic: |
|
loss_dict['twin_critic_loss'] += min_qf2_loss |
|
|
|
|
|
self._optimizer_q.zero_grad() |
|
loss_dict['critic_loss'].backward(retain_graph=True) |
|
if self._twin_critic: |
|
loss_dict['twin_critic_loss'].backward() |
|
self._optimizer_q.step() |
|
|
|
|
|
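        # Policy (actor) loss: standard SAC reparameterized objective, i.e. maximize Q minus the entropy
        # temperature times the log-probability of the tanh-squashed Gaussian action.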
(mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit'] |
|
dist = Independent(Normal(mu, sigma), 1) |
|
pred = dist.rsample() |
|
action = torch.tanh(pred) |
|
y = 1 - action.pow(2) + 1e-6 |
|
log_prob = dist.log_prob(pred).unsqueeze(-1) |
|
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) |
|
|
|
eval_data = {'obs': obs, 'action': action} |
|
new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value'] |
|
if self._twin_critic: |
|
new_q_value = torch.min(new_q_value[0], new_q_value[1]) |
|
|
|
|
|
policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean() |
|
|
|
loss_dict['policy_loss'] = policy_loss |
|
|
|
|
|
self._optimizer_policy.zero_grad() |
|
loss_dict['policy_loss'].backward() |
|
self._optimizer_policy.step() |
|
|
|
|
|
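        # Entropy temperature (alpha) loss, driving the policy entropy towards the target entropy.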
if self._auto_alpha: |
|
if self._log_space: |
|
log_prob = log_prob + self._target_entropy |
|
loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean() |
|
|
|
self._alpha_optim.zero_grad() |
|
loss_dict['alpha_loss'].backward() |
|
self._alpha_optim.step() |
|
self._alpha = self._log_alpha.detach().exp() |
|
else: |
|
log_prob = log_prob + self._target_entropy |
|
loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean() |
|
|
|
self._alpha_optim.zero_grad() |
|
loss_dict['alpha_loss'].backward() |
|
self._alpha_optim.step() |
|
self._alpha = max(0, self._alpha) |
|
|
|
loss_dict['total_loss'] = sum(loss_dict.values()) |
|
|
|
|
|
|
|
|
|
self._forward_learn_cnt += 1 |
|
|
|
self._target_model.update(self._learn_model.state_dict()) |
|
return { |
|
'cur_lr_q': self._optimizer_q.defaults['lr'], |
|
'cur_lr_p': self._optimizer_policy.defaults['lr'], |
|
'priority': td_error_per_sample.abs().tolist(), |
|
'td_error': td_error_per_sample.detach().mean().item(), |
|
'alpha': self._alpha.item(), |
|
'target_q_value': target_q_value.detach().mean().item(), |
|
**loss_dict |
|
} |
|
|
|
    def _get_policy_actions(self, data: Dict, num_actions: int = 10,
                            epsilon: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor]:
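        """
        Overview:
            Sample ``num_actions`` tanh-squashed Gaussian actions for each observation in ``data`` and return the \
            actions (flattened over the batch and sample dimensions) together with their log-probabilities \
            (including the tanh change-of-variables correction), reshaped to ``(batch_size, num_actions, 1)``.
        """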
|
|
|
obs = data['obs'] |
|
obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1]) |
|
(mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit'] |
|
dist = Independent(Normal(mu, sigma), 1) |
|
pred = dist.rsample() |
|
action = torch.tanh(pred) |
|
|
|
|
|
y = 1 - action.pow(2) + epsilon |
|
log_prob = dist.log_prob(pred).unsqueeze(-1) |
|
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) |
|
|
|
return action, log_prob.view(-1, num_actions, 1) |
|
|
|
    def _get_q_value(self, data: Dict, keep: bool = True) -> Union[torch.Tensor, List[torch.Tensor]]:
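        """
        Overview:
            Compute Q-values for the repeated obs-action pairs and reshape them to ``(batch_size, num_actions, 1)``. \
            When ``keep`` is False and the twin critic is used, return the element-wise minimum of the two critics.
        """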
|
new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value'] |
|
if self._twin_critic: |
|
new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value] |
|
else: |
|
new_q_value = new_q_value.view(-1, self._num_actions, 1) |
|
if self._twin_critic and not keep: |
|
new_q_value = torch.min(new_q_value[0], new_q_value[1]) |
|
return new_q_value |
|
|
|
|
|
@POLICY_REGISTRY.register('discrete_cql') |
|
class DiscreteCQLPolicy(QRDQNPolicy): |
|
""" |
|
Overview: |
|
        Policy class of the discrete CQL algorithm for discrete action space environments, built on top of QRDQN.
|
Paper link: https://arxiv.org/abs/2006.04779. |
|
""" |
|
|
|
config = dict( |
|
|
|
type='discrete_cql', |
|
|
|
cuda=False, |
|
|
|
on_policy=False, |
|
|
|
priority=False, |
|
|
|
discount_factor=0.97, |
|
|
|
nstep=1, |
|
|
|
learn=dict( |
|
|
|
|
|
update_per_collect=1, |
|
|
|
batch_size=64, |
|
|
|
learning_rate=0.001, |
|
|
|
target_update_freq=100, |
|
|
|
ignore_done=False, |
|
|
|
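            # (float) Weight of the conservative (min-Q) regularization term added to the QRDQN loss.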
min_q_weight=1.0, |
|
), |
|
eval=dict(), |
|
) |
|
|
|
def _init_learn(self) -> None: |
|
""" |
|
Overview: |
|
Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \ |
|
contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, main and \ |
|
target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. |
|
|
|
.. note:: |
|
For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ |
|
and ``_load_state_dict_learn`` methods. |
|
|
|
.. note:: |
|
For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. |
|
|
|
.. note:: |
|
            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
|
with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. |
|
""" |
|
self._min_q_weight = self._cfg.learn.min_q_weight |
|
self._priority = self._cfg.priority |
|
|
|
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate) |
|
|
|
self._gamma = self._cfg.discount_factor |
|
self._nstep = self._cfg.nstep |
|
|
|
|
|
self._target_model = copy.deepcopy(self._model) |
|
self._target_model = model_wrap( |
|
self._target_model, |
|
wrapper_name='target', |
|
update_type='assign', |
|
update_kwargs={'freq': self._cfg.learn.target_update_freq} |
|
) |
|
self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') |
|
self._learn_model.reset() |
|
self._target_model.reset() |
|
|
|
def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: |
|
""" |
|
Overview: |
|
Policy forward function of learn mode (training policy and updating parameters). Forward means \ |
|
that the policy inputs some training batch data from the offline dataset and then returns the output \ |
|
result, including various training information such as loss, action, priority. |
|
Arguments: |
|
            - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
|
training samples. For each element in list, the key of the dict is the name of data items and the \ |
|
                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
                combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
|
dimension by some utility functions such as ``default_preprocess_learn``. \ |
|
For DiscreteCQL, each element in list is a dict containing at least the following keys: ``obs``, \ |
|
``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \ |
|
and ``value_gamma`` for nstep return computation. |
|
Returns: |
|
            - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
|
recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ |
|
detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. |
|
|
|
.. note:: |
|
            The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of \
            them. For data types that are not supported, the main reason is that the corresponding model does not \
            support them. You can implement your own model rather than using the default model. For more \
            information, please raise an issue in the GitHub repo and we will continue to follow up.
|
""" |
|
data = default_preprocess_learn( |
|
data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True |
|
) |
|
if self._cuda: |
|
data = to_device(data, self._device) |
|
if data['action'].dim() == 2 and data['action'].shape[-1] == 1: |
|
data['action'] = data['action'].squeeze(-1) |
|
|
|
|
|
|
|
self._learn_model.train() |
|
self._target_model.train() |
|
|
|
ret = self._learn_model.forward(data['obs']) |
|
q_value, tau = ret['q'], ret['tau'] |
|
|
|
with torch.no_grad(): |
|
target_q_value = self._target_model.forward(data['next_obs'])['q'] |
|
|
|
target_q_action = self._learn_model.forward(data['next_obs'])['action'] |
|
|
|
|
|
|
|
|
|
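        # Discrete CQL penalty: logsumexp of the quantile-averaged Q-values over all actions minus the mean
        # Q-value of the actions actually taken in the dataset.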
replay_action_one_hot = F.one_hot(data['action'], self._cfg.model.action_shape) |
|
replay_chosen_q = (q_value.mean(-1) * replay_action_one_hot).sum(dim=1) |
|
|
|
dataset_expec = replay_chosen_q.mean() |
|
|
|
negative_sampling = torch.logsumexp(q_value.mean(-1), dim=1).mean() |
|
|
|
min_q_loss = negative_sampling - dataset_expec |
|
|
|
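        # Standard QRDQN n-step (distributional) TD loss; the conservative penalty is added on top afterwards.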
data_n = qrdqn_nstep_td_data( |
|
q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], tau, data['weight'] |
|
) |
|
value_gamma = data.get('value_gamma') |
|
loss, td_error_per_sample = qrdqn_nstep_td_error( |
|
data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma |
|
) |
|
|
|
loss += self._min_q_weight * min_q_loss |
|
|
|
|
|
|
|
|
|
self._optimizer.zero_grad() |
|
loss.backward() |
|
if self._cfg.multi_gpu: |
|
self.sync_gradients(self._learn_model) |
|
self._optimizer.step() |
|
|
|
|
|
|
|
|
|
self._target_model.update(self._learn_model.state_dict()) |
|
return { |
|
'cur_lr': self._optimizer.defaults['lr'], |
|
'total_loss': loss.item(), |
|
'priority': td_error_per_sample.abs().tolist(), |
|
'q_target': target_q_value.mean().item(), |
|
'q_value': q_value.mean().item(), |
|
|
|
|
|
} |
|
|
|
def _monitor_vars_learn(self) -> List[str]: |
|
""" |
|
Overview: |
|
Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \ |
|
as text logger, tensorboard logger, will use these keys to save the corresponding data. |
|
Returns: |
|
- necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged. |
|
""" |
|
return ['cur_lr', 'total_loss', 'q_target', 'q_value'] |
|
|