from typing import List, Dict, Any, Tuple, Union

import torch

from ding.policy import PPOPolicy, PPOOffPolicy
from ding.rl_utils import ppo_data, ppo_error, gae, gae_data
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.torch_utils import to_device
from ding.policy.common_utils import default_preprocess_learn


@POLICY_REGISTRY.register('md_ppo')
class MultiDiscretePPOPolicy(PPOPolicy):
    r"""
    Overview:
        Policy class of the multi-discrete action space PPO algorithm. Compared with ``PPOPolicy``, the \
        ``logit`` and ``action`` fields of the training data are sequences with one element per discrete \
        sub-action space, and the PPO loss is computed for each sub-action space and then averaged.
    """

    def _forward_learn(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): A batch of collected transitions with keys such as ``obs``, ``next_obs``, \
                ``action``, ``logit``, ``value``, ``reward`` and ``done``.
        Returns:
            - return_infos (:obj:`List[Dict[str, Any]]`): One info dict per mini-batch update, each \
                including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
                adv_max, adv_mean, value_max, value_mean, approx_kl, clipfrac.
        """
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)

        return_infos = []
        self._learn_model.train()

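        # Train for multiple epochs on the same on-policy batch. When ``recompute_adv`` is enabled,
        # GAE advantages are refreshed with the current value network at the start of every epoch.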
        for epoch in range(self._cfg.learn.epoch_per_collect):
            if self._recompute_adv:
                with torch.no_grad():
                    value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
                    next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
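                    # The critic is trained on normalized returns when ``value_norm`` is enabled, so scale
                    # its predictions back by the running std before computing GAE on raw rewards.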
                    if self._value_norm:
                        value *= self._running_mean_std.std
                        next_value *= self._running_mean_std.std
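                    # Pack value estimates, rewards, done masks and trajectory flags, then recompute
                    # GAE advantages with the configured discount and lambda.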
                    compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
                    data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
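                    # The return target is value + advantage; with ``value_norm`` the stored value and
                    # return are kept in normalized space and the running statistics are updated with
                    # the unnormalized returns.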
                    unnormalized_returns = value + data['adv']

                    if self._value_norm:
                        data['value'] = value / self._running_mean_std.std
                        data['return'] = unnormalized_returns / self._running_mean_std.std
                        self._running_mean_std.update(unnormalized_returns.cpu().numpy())
                    else:
                        data['value'] = value
                        data['return'] = unnormalized_returns
            else:
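                # Advantages were already computed during collection; only derive the return targets,
                # re-normalizing them when ``value_norm`` is enabled.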
                if self._value_norm:
                    unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
                    data['return'] = unnormalized_return / self._running_mean_std.std
                    self._running_mean_std.update(unnormalized_return.cpu().numpy())
                else:
                    data['return'] = data['adv'] + data['value']

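            # Iterate over shuffled mini-batches of the whole batch; advantages are normalized per
            # mini-batch when ``adv_norm`` is enabled.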
            for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
                output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
                adv = batch['adv']
                if self._adv_norm:
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)

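                # Multi-discrete handling: compute an independent PPO loss for each discrete
                # sub-action space (sharing the same value, advantage and return targets), then
                # average the losses and statistics over all sub-action spaces.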
                loss_list = []
                info_list = []
                action_num = len(batch['action'])
                for i in range(action_num):
                    ppo_batch = ppo_data(
                        output['logit'][i], batch['logit'][i], batch['action'][i], output['value'], batch['value'],
                        adv, batch['return'], batch['weight']
                    )
                    ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
                    loss_list.append(ppo_loss)
                    info_list.append(ppo_info)
                avg_policy_loss = sum([item.policy_loss for item in loss_list]) / action_num
                avg_value_loss = sum([item.value_loss for item in loss_list]) / action_num
                avg_entropy_loss = sum([item.entropy_loss for item in loss_list]) / action_num
                avg_approx_kl = sum([item.approx_kl for item in info_list]) / action_num
                avg_clipfrac = sum([item.clipfrac for item in info_list]) / action_num

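                # Weighted combination of the averaged losses: the value loss is added with weight
                # ``value_weight`` and the entropy term is subtracted (weighted by ``entropy_weight``)
                # to encourage exploration.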
                wv, we = self._value_weight, self._entropy_weight
                total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss

                self._optimizer.zero_grad()
                total_loss.backward()
                self._optimizer.step()

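                # Per-mini-batch training statistics for logging.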
                return_info = {
                    'cur_lr': self._optimizer.defaults['lr'],
                    'total_loss': total_loss.item(),
                    'policy_loss': avg_policy_loss.item(),
                    'value_loss': avg_value_loss.item(),
                    'entropy_loss': avg_entropy_loss.item(),
                    'adv_max': adv.max().item(),
                    'adv_mean': adv.mean().item(),
                    'value_mean': output['value'].mean().item(),
                    'value_max': output['value'].max().item(),
                    'approx_kl': avg_approx_kl,
                    'clipfrac': avg_clipfrac,
                }
                return_infos.append(return_info)
        return return_infos


@POLICY_REGISTRY.register('md_ppo_offpolicy')
class MultiDiscretePPOOffPolicy(PPOOffPolicy):
    r"""
    Overview:
        Policy class of the multi-discrete action space off-policy PPO algorithm. As in \
        ``MultiDiscretePPOPolicy``, the ``logit`` and ``action`` fields of the training data are sequences \
        with one element per discrete sub-action space, and the PPO loss is averaged over them.
    """

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode.
        Arguments:
            - data (:obj:`dict`): A batch of sampled transitions with keys such as ``obs``, ``action``, \
                ``logit``, ``value``, ``adv`` and ``weight``.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including current lr, total_loss, policy_loss, value_loss, \
                entropy_loss, adv_abs_max, approx_kl, clipfrac.
        """
        assert not self._nstep_return
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()

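        # Off-policy variant: perform a single update on the sampled batch. Advantages come from the
        # collector, and the return target is the stored value prediction plus the advantage.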
        output = self._learn_model.forward(data['obs'], mode='compute_actor_critic')
        adv = data['adv']
        return_ = data['value'] + adv
        if self._adv_norm:
            adv = (adv - adv.mean()) / (adv.std() + 1e-8)

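        # As in ``MultiDiscretePPOPolicy``, compute one PPO loss per discrete sub-action space and
        # average the losses and statistics over all sub-action spaces.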
        loss_list = []
        info_list = []
        action_num = len(data['action'])
        for i in range(action_num):
            ppodata = ppo_data(
                output['logit'][i], data['logit'][i], data['action'][i], output['value'], data['value'], adv,
                return_, data['weight']
            )
            ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
            loss_list.append(ppo_loss)
            info_list.append(ppo_info)
        avg_policy_loss = sum([item.policy_loss for item in loss_list]) / action_num
        avg_value_loss = sum([item.value_loss for item in loss_list]) / action_num
        avg_entropy_loss = sum([item.entropy_loss for item in loss_list]) / action_num
        avg_approx_kl = sum([item.approx_kl for item in info_list]) / action_num
        avg_clipfrac = sum([item.clipfrac for item in info_list]) / action_num

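        # Combine the averaged losses with the configured value and entropy weights.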
        wv, we = self._value_weight, self._entropy_weight
        total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss

        self._optimizer.zero_grad()
        total_loss.backward()
        self._optimizer.step()
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': avg_policy_loss.item(),
            'value_loss': avg_value_loss.item(),
            'entropy_loss': avg_entropy_loss.item(),
            'adv_abs_max': adv.abs().max().item(),
            'approx_kl': avg_approx_kl,
            'clipfrac': avg_clipfrac,
        }
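

if __name__ == '__main__':
    # Illustrative sketch only (not part of the policy). It demonstrates the per-sub-action-space
    # PPO loss averaging used in ``_forward_learn`` with random stand-in tensors, reusing the
    # ``ppo_data``/``ppo_error`` helpers imported above. All shapes and values here are assumptions
    # for demonstration: batch size 4 and two discrete sub-action spaces with 3 and 5 choices.
    B, action_dims = 4, (3, 5)
    value_new, value_old = torch.randn(B), torch.randn(B)
    adv, return_, weight = torch.randn(B), torch.randn(B), torch.ones(B)
    loss_list, info_list = [], []
    for n in action_dims:
        # Random logits stand in for the actor outputs of one sub-action head.
        logit_new, logit_old = torch.randn(B, n), torch.randn(B, n)
        action = torch.randint(0, n, (B, ))
        loss, info = ppo_error(ppo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, weight), 0.2)
        loss_list.append(loss)
        info_list.append(info)
    avg_policy_loss = sum(l.policy_loss for l in loss_list) / len(action_dims)
    avg_approx_kl = sum(i.approx_kl for i in info_list) / len(action_dims)
    print('averaged policy loss over sub-action spaces:', avg_policy_loss.item())
    print('averaged approx_kl over sub-action spaces:', avg_approx_kl)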