File size: 8,926 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
import treetensor as ttorch
from ding.rl_utils import get_gae_with_default_last_value, get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
@POLICY_REGISTRY.register('pg')
class PGPolicy(Policy):
    r"""
    Overview:
        Policy class of Policy Gradient (REINFORCE) algorithm.
        The learn loss is ``-(log_prob(action) * return).mean()`` minus a weighted entropy bonus, where
        ``return`` is the discounted Monte-Carlo return computed over a complete episode in
        ``_get_train_sample``.
    """
    config = dict(
        # (string) RL policy register name (refer to function "register_policy").
        type='pg',
        # (bool) whether to use cuda for network.
        cuda=False,
        # (bool) whether use on-policy training pipeline (behaviour policy and training policy are the same)
        on_policy=True,  # PG is a strictly on-policy algorithm; this line should not be modified by users
        # (str) action space type: ['discrete', 'continuous']
        action_space='discrete',
        # (bool) whether to use deterministic action for evaluation.
        deterministic_eval=True,
        learn=dict(
            # (int) the number of samples for one update.
            batch_size=64,
            # (float) the step size of one gradient descent step.
            learning_rate=0.001,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
            entropy_weight=0.01,
            # (float) max grad norm value.
            grad_norm=5,
            # (bool) whether to ignore done signal for non-termination env.
            # NOTE: ``_get_train_sample`` raises NotImplementedError when this is True.
            ignore_done=False,
        ),
        collect=dict(
            # (int) number of episodes to collect before each training phase (left unset here)
            # n_episode=8,
            # (int) trajectory unroll length
            unroll_len=1,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) discount factor for future reward, typically in [0, 1]
            discount_factor=0.99,
            collector=dict(get_train_sample=True),
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the registered name of the default model and the module paths it is defined in.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): ``('pg', ['ding.model.template.pg'])``.
        """
        return 'pg', ['ding.model.template.pg']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer and cache the learn hyper-parameters. PG keeps no target model, so the
            learn model is simply ``self._model``.
        """
        # Optimizer
        self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._grad_norm = self._cfg.learn.grad_norm
        self._learn_model = self._model  # for compatibility

    def _forward_learn(self, data: dict) -> List[Dict[str, Any]]:
        r"""
        Overview:
            Forward and backward function of learn mode. The data is split into shuffled mini-batches of
            ``cfg.learn.batch_size`` and one optimizer step is performed per mini-batch.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'return'] \
                ('return' is the discounted return written by ``_get_train_sample``). Further keys may be \
                required by ``default_preprocess_learn``.
        Returns:
            - info_list (:obj:`List[Dict[str, Any]]`): One info dict per mini-batch update, with the current \
                lr, total/policy/entropy losses, the max absolute return and the gradient norm.
        """
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)
        self._model.train()
        return_infos = []
        for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
            # forward
            output = self._learn_model.forward(batch['obs'])
            return_ = batch['return']
            dist = output['dist']
            # REINFORCE loss: minimizing -(log_prob * return) ascends the policy gradient
            log_prob = dist.log_prob(batch['action'])
            policy_loss = -(log_prob * return_).mean()
            # negative weighted entropy: minimizing it rewards higher-entropy (more exploratory) policies
            entropy_loss = -self._entropy_weight * dist.entropy().mean()
            total_loss = policy_loss + entropy_loss
            # update
            self._optimizer.zero_grad()
            total_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(self._learn_model.parameters()),
                max_norm=self._grad_norm,
            )
            self._optimizer.step()
            # record the statistics of every mini-batch update
            return_infos.append(
                {
                    'cur_lr': self._optimizer.param_groups[0]['lr'],
                    'total_loss': total_loss.item(),
                    'policy_loss': policy_loss.item(),
                    'entropy_loss': entropy_loss.item(),
                    'return_abs_max': return_.abs().max().item(),
                    # plain float like the other scalars, so logged info never holds a (possibly GPU) tensor
                    'grad_norm': float(grad_norm),
                }
            )
        return return_infos

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Cache the unroll length and discount factor used by
            ``_get_train_sample``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._gamma = self._cfg.collect.discount_factor

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of collect mode: run the model without gradients and sample an action from
            its output distribution for every entry of ``data``.
        Arguments:
            - data (:obj:`dict`): Mapping from id (presumably env id) to observation.
        Returns:
            - output (:obj:`dict`): Mapping from the same ids to the per-entry model output, which \
                contains an extra 'action' key.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._model.eval()
        with torch.no_grad():
            output = self._model.forward(data)
            output['action'] = output['dist'].sample()
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done']
        Returns:
            - transition (:obj:`dict`): Dict type transition data with keys ['obs', 'action', 'reward', 'done'].
        """
        return {
            'obs': obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Compute the discounted Monte-Carlo return ``G_t = r_t + gamma * G_{t+1}`` for every step of a
            complete episode, store it under the ``'return'`` key (in place), then split the episode into
            training samples of length ``unroll_len`` via ``get_train_sample``.
        Arguments:
            - data (:obj:`Union[list, ttorch.Tensor]`): One complete episode, either a list of transition \
                dicts or a treetensor whose fields are indexed along their first dimension by time step.
        Returns:
            - samples (:obj:`List[Any]`): The training samples generated by ``get_train_sample``.
        Raises:
            - AssertionError: if the last transition is not terminal (incomplete episode).
            - NotImplementedError: if ``cfg.learn.ignore_done`` is True.
            - ValueError: if ``data`` is neither a list nor a treetensor.
        """
        # explicit check instead of ``assert`` so it is not stripped under ``python -O``;
        # the exception type is kept as AssertionError for compatibility
        if not data[-1]['done']:
            raise AssertionError("PG needs a complete episode")
        if self._cfg.learn.ignore_done:
            raise NotImplementedError("PG does not support ignore_done=True")
        R = 0.
        if isinstance(data, list):
            # accumulate the discounted return backwards from the terminal step
            for transition in reversed(data):
                R = self._gamma * R + transition['reward']
                transition['return'] = R
            return get_train_sample(data, self._unroll_len)
        elif isinstance(data, ttorch.Tensor):
            data_size = data['done'].shape[0]
            data['return'] = ttorch.torch.zeros(data_size)
            for i in reversed(range(data_size)):
                R = self._gamma * R + data['reward'][i]
                data['return'][i] = R
            return get_train_sample(data, self._unroll_len)
        else:
            raise ValueError("unsupported data type for PG train sample: {}".format(type(data)))

    def _init_eval(self) -> None:
        r"""
        Overview:
            Eval mode init method. PG needs no extra eval state.
        """
        pass

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode. With ``cfg.deterministic_eval`` the action is the argmax of the
            logits (discrete) or the ``'mu'`` entry of the logit output (continuous); otherwise it is sampled
            from the model's distribution.
        Arguments:
            - data (:obj:`dict`): Mapping from id (presumably env id) to observation.
        Returns:
            - output (:obj:`dict`): Mapping from the same ids to the per-entry model output with an 'action' key.
        Raises:
            - KeyError: if deterministic eval is enabled and ``cfg.action_space`` is unsupported.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._model.eval()
        with torch.no_grad():
            output = self._model.forward(data)
            if self._cfg.deterministic_eval:
                if self._cfg.action_space == 'discrete':
                    output['action'] = output['logit'].argmax(dim=-1)
                elif self._cfg.action_space == 'continuous':
                    output['action'] = output['logit']['mu']
                else:
                    raise KeyError("invalid action_space: {}".format(self._cfg.action_space))
            else:
                output['action'] = output['dist'].sample()
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Names of the learn-mode variables to monitor: the base policy's variables plus the PG-specific
            keys returned by ``_forward_learn``.
        """
        return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']
|