from typing import Dict, Any

import torch

from ding.torch_utils import to_device
from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, dist_1step_td_data, dist_1step_td_error
from ding.policy import RainbowDQNPolicy
from ding.utils import POLICY_REGISTRY
from ding.policy.common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('md_rainbow_dqn')
class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
    r"""
    Overview:
        Multi-discrete action space Rainbow DQN algorithm (one independent discrete action head per action dimension).
    """

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode: acquire the data, calculate the distributional \
            n-step TD loss and optimize the learner model.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'next_obs', 'reward', 'action']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including cur_lr, total_loss and priority
                - cur_lr (:obj:`float`): Current learning rate
                - total_loss (:obj:`float`): The calculated loss
                - priority (:obj:`list`): The priority of samples
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
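        # ====================
        # Rainbow forward
        # ====================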
        self._learn_model.train()
        self._target_model.train()
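        # NoisyNet exploration: resample the noise of the noisy layers in both networks before the forward passes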
        self._reset_noise(self._learn_model)
        self._reset_noise(self._target_model)
        q_dist = self._learn_model.forward(data['obs'])['distribution']
        with torch.no_grad():
            target_q_dist = self._target_model.forward(data['next_obs'])['distribution']
            # select the greedy next action with the learn model (Double DQN style) under freshly resampled noise
            self._reset_noise(self._learn_model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']
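        # value_gamma, when provided by the collector, is the per-sample cumulative discount of the n-step return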
        value_gamma = data.get('value_gamma', None)
        if isinstance(q_dist, torch.Tensor):
            td_data = dist_nstep_td_data(
                q_dist, target_q_dist, data['action'], target_q_action, data['reward'], data['done'], data['weight']
            )
            loss, td_error_per_sample = dist_nstep_td_error(
                td_data,
                self._gamma,
                self._v_min,
                self._v_max,
                self._n_atom,
                nstep=self._nstep,
                value_gamma=value_gamma
            )
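        # multi-discrete case: q_dist is a list with one value distribution per action branch, so compute
        # the distributional n-step TD loss per branch and average the losses and per-sample TD errors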
        else:
            act_num = len(q_dist)
            losses = []
            td_error_per_samples = []
            for i in range(act_num):
                td_data = dist_nstep_td_data(
                    q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
                    data['weight']
                )
                td_loss, td_error_per_sample = dist_nstep_td_error(
                    td_data,
                    self._gamma,
                    self._v_min,
                    self._v_max,
                    self._n_atom,
                    nstep=self._nstep,
                    value_gamma=value_gamma
                )
                losses.append(td_loss)
                td_error_per_samples.append(td_error_per_sample)
            loss = sum(losses) / (len(losses) + 1e-8)
            # expose the same variable name as the single-action branch so the priority below is always defined
            td_error_per_sample = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)
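        # ====================
        # Rainbow update
        # ====================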
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
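        # =============
        # after update
        # =============
        # sync the target network towards the learn network according to the configured target update rule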
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample.abs().tolist(),
        }
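
# Rough usage sketch (an illustration, not part of the policy): since the class is registered as
# 'md_rainbow_dqn', a DI-engine config can select it with ``policy=dict(type='md_rainbow_dqn', ...)``
# (adding ``import_names`` for this module if it is not imported by default), after which the standard
# entry flow builds it via ``ding.policy.create_policy``. The wrapped model's forward is expected to
# return the value distribution under the 'distribution' key: a single tensor for one discrete action
# head, or a list of per-branch tensors in the multi-discrete case handled by ``_forward_learn`` above.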