zjowowen's picture
init space
079c32c
raw
history blame
12.2 kB
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
from ding.rl_utils import a2c_data, a2c_error, get_gae_with_default_last_value, get_train_sample, \
a2c_error_continuous
from ding.torch_utils import Adam, to_device
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY, split_data_generator
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
@POLICY_REGISTRY.register('a2c')
class A2CPolicy(Policy):
    r"""
    Overview:
        Policy class of the A2C (advantage actor-critic) algorithm. It supports both ``discrete`` and
        ``continuous`` action spaces, selected by ``cfg.action_space``.
    """
    config = dict(
        # (string) RL policy register name (refer to function "register_policy").
        type='a2c',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same)
        on_policy=True,  # for a2c strictly on policy algorithm, this line should not be seen by users
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (str) Which kind of action space used in A2CPolicy, ['discrete', 'continuous']
        action_space='discrete',
        learn=dict(
            # (int) for a2c, update_per_collect must be 1.
            update_per_collect=1,  # fixed value, this line should not be modified by users
            # (int) Mini-batch size used to split the collected data in ``_forward_learn``.
            batch_size=64,
            # (float) Learning rate passed to the Adam optimizer.
            learning_rate=0.001,
            # (List[float]) ``betas`` passed to the Adam optimizer.
            betas=(0.9, 0.999),
            # (float) ``eps`` passed to the Adam optimizer.
            eps=1e-8,
            # (float) ``max_norm`` used for gradient norm clipping.
            grad_norm=0.5,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) loss weight of the value network, the weight of policy network is set to 1
            value_weight=0.5,
            # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
            entropy_weight=0.01,
            # (bool) Whether to normalize advantage. Default to False.
            adv_norm=False,
            # (bool) Forwarded to ``default_preprocess_learn`` as ``ignore_done``.
            ignore_done=False,
        ),
        collect=dict(
            # (int) collect n_sample data, train model n_iteration times
            # n_sample=80,
            # (int) Unroll length forwarded to ``get_train_sample``.
            unroll_len=1,
            # ==============================================================
            # The following configs is algorithm-specific
            # ==============================================================
            # (float) discount factor for future reward, defaults in [0, 1]
            discount_factor=0.9,
            # (float) the trade-off factor lambda to balance 1step td and mc
            gae_lambda=0.95,
        ),
        eval=dict(),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the name of the default model and the list of module paths where it can be found.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): ``'vac'`` and ``['ding.model.template.vac']``.
        """
        return 'vac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, algorithm config and the learn model.
        """
        assert self._cfg.action_space in ["continuous", "discrete"]
        # ``_forward_learn`` dispatches on the action space, so it must be set here as well:
        # relying on ``_init_collect``/``_init_eval`` breaks when only learn mode is enabled.
        self._action_space = self._cfg.action_space
        # Optimizer
        self._optimizer = Adam(
            self._model.parameters(),
            lr=self._cfg.learn.learning_rate,
            betas=self._cfg.learn.betas,
            eps=self._cfg.learn.eps
        )

        # Algorithm config
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._value_weight = self._cfg.learn.value_weight
        self._entropy_weight = self._cfg.learn.entropy_weight
        self._adv_norm = self._cfg.learn.adv_norm
        self._grad_norm = self._cfg.learn.grad_norm

        # Learn model (plain wrapper around the shared network)
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode. The preprocessed data is split into mini-batches
            of ``cfg.learn.batch_size`` and one optimizer step is taken per mini-batch.
        Arguments:
            - data (:obj:`dict`): Training data. After preprocessing, each mini-batch is read through the keys
              ``obs``, ``action``, ``value``, ``adv`` and ``weight`` (``weight`` is presumably filled in by
              ``default_preprocess_learn`` -- TODO confirm).
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Current lr, losses, max absolute advantage and gradient norm
              of the last mini-batch update.
        """
        data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()

        for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
            # forward
            output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')

            adv = batch['adv']
            # Value target: value stored with the batch plus the advantage.
            # It is computed before normalization so the target keeps the raw scale.
            return_ = batch['value'] + adv
            if self._adv_norm:
                # normalize adv with the statistics of the current mini-batch
                adv = (adv - adv.mean()) / (adv.std() + 1e-8)
            error_data = a2c_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight'])

            # Calculate A2C loss
            if self._action_space == 'continuous':
                a2c_loss = a2c_error_continuous(error_data)
            elif self._action_space == 'discrete':
                a2c_loss = a2c_error(error_data)

            wv, we = self._value_weight, self._entropy_weight
            # The entropy term is subtracted, i.e. higher entropy lowers the total loss.
            total_loss = a2c_loss.policy_loss + wv * a2c_loss.value_loss - we * a2c_loss.entropy_loss

            # ====================
            # A2C-learning update
            # ====================
            self._optimizer.zero_grad()
            total_loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(self._learn_model.parameters()),
                max_norm=self._grad_norm,
            )
            self._optimizer.step()

        # =============
        # after update
        # =============
        # only record last updates information in logger
        return {
            'cur_lr': self._optimizer.param_groups[0]['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': a2c_loss.policy_loss.item(),
            'value_loss': a2c_loss.value_loss.item(),
            'entropy_loss': a2c_loss.entropy_loss.item(),
            'adv_abs_max': adv.abs().max().item(),
            'grad_norm': grad_norm,
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state dicts of the learn model and the optimizer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): Dict with the keys ``model`` and ``optimizer``.
        """
        return {
            'model': self._learn_model.state_dict(),
            'optimizer': self._optimizer.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Load the learn model and optimizer states, the counterpart of ``_state_dict_learn``.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Dict with the keys ``model`` and ``optimizer``.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer.load_state_dict(state_dict['optimizer'])

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init unroll length, the sampling collect model and the GAE hyper-parameters.
        """
        assert self._cfg.action_space in ["continuous", "discrete"]
        self._unroll_len = self._cfg.collect.unroll_len

        self._action_space = self._cfg.action_space
        # Pick the sampling wrapper that matches the action space.
        if self._action_space == 'continuous':
            self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
        elif self._action_space == 'discrete':
            self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
        self._collect_model.reset()
        # Algorithm
        self._gamma = self._cfg.collect.discount_factor
        self._gae_lambda = self._cfg.collect.gae_lambda

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of collect mode.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action),
              values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
        ReturnsKeys
            - necessary: ``action``
        """
        # Remember the env ids so the batched output can be mapped back per environment.
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            output = self._collect_model.forward(data, mode='compute_actor_critic')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, read through the keys ``action`` and ``value``.
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']
              (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data with the keys
              ``obs``, ``next_obs``, ``action``, ``value``, ``reward`` and ``done``.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'value': model_output['value'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Process a collected trajectory with ``get_gae_with_default_last_value`` (using the ``done`` flag
            of its last transition, ``gamma`` and ``gae_lambda``), then cut it into training samples of
            ``unroll_len`` with ``get_train_sample``.
        Arguments:
            - data (:obj:`list`): The trajectory's buffer list
        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        data = get_gae_with_default_last_value(
            data,
            data[-1]['done'],
            gamma=self._gamma,
            gae_lambda=self._gae_lambda,
            cuda=self._cuda,
        )
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model with the ``argmax_sample`` wrapper for discrete action space and the
            ``deterministic_sample`` wrapper for continuous action space.
        """
        assert self._cfg.action_space in ["continuous", "discrete"]
        self._action_space = self._cfg.action_space
        if self._action_space == 'continuous':
            self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
        elif self._action_space == 'discrete':
            self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of eval mode, similar to ``self._forward_collect``, but only the actor
            is computed (``mode='compute_actor'``).
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action),
              values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
        Returns:
            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
        ReturnsKeys
            - necessary: ``action``
        """
        # Remember the env ids so the batched output can be mapped back per environment.
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_actor')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to be monitored in learn mode: the ones of the parent class
            plus the A2C-specific entries returned by ``_forward_learn``.
        Returns:
            - vars (:obj:`List[str]`): List of variable names.
        """
        return super()._monitor_vars_learn() + [
            'policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'grad_norm'
        ]