File size: 15,386 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 |
from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple
import torch
import copy
from ding.torch_utils import RMSprop, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
v_nstep_td_data, v_nstep_td_error, get_nstep_return_data
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import timestep_collate, default_collate, default_decollate
from .qmix import QMIXPolicy
@POLICY_REGISTRY.register('madqn')
class MADQNPolicy(QMIXPolicy):
    r"""
    Overview:
        Multi-agent DQN policy that reuses the ``QMIXPolicy`` machinery. The model is expected to expose two
        parameter groups, ``current`` and ``cooperation``. Each group is trained by its own RMSprop optimizer
        against a TD target produced by a momentum-updated copy of the whole model.
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='madqn',
        # (bool) Whether to use cuda for network.
        cuda=True,
        # (bool) Whether the RL algorithm is on-policy or off-policy.
        on_policy=False,
        # (bool) Whether to use priority (priority sample, IS weight, update priority).
        # Not implemented for this policy: ``_init_learn`` asserts it is False.
        priority=False,
        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
        # Not implemented for this policy: ``_init_learn`` asserts it is False.
        priority_IS_weight=False,
        # (int) Number of steps of the TD target: 1 means one-step TD, >1 means n-step TD.
        nstep=3,
        learn=dict(
            update_per_collect=20,
            # (int) Training batch size. Also the number of hidden states kept by the hidden-state wrappers.
            batch_size=32,
            # (float) Learning rate of both RMSprop optimizers.
            learning_rate=0.0005,
            # (float) Max gradient norm used by ``clip_grad_norm_`` for both parameter groups.
            clip_value=100,
            # ==============================================================
            # The following configs are algorithm-specific
            # ==============================================================
            # (float) Target network update momentum parameter.
            # in [0, 1].
            target_update_theta=0.008,
            # (float) The discount factor for future rewards,
            # in [0, 1].
            discount_factor=0.99,
            # (bool) Whether to use the double DQN mechanism (next actions are selected by the learn model and
            # evaluated by the target model) to reduce Q-value over-estimation.
            double_q=False,
            # (float) Weight decay of both RMSprop optimizers.
            weight_decay=1e-5,
        ),
        collect=dict(
            # (int) Only one of [n_sample, n_episode] should be set.
            n_episode=32,
            # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps
            # in each forward when training. It is greater than 1 because the model carries a hidden state (RNN).
            unroll_len=10,
        ),
        eval=dict(),
        other=dict(
            eps=dict(
                # (str) Type of epsilon decay.
                type='exp',
                # (float) Start value for epsilon decay, in [0, 1].
                # 0 means not use epsilon decay.
                start=1,
                # (float) End value for epsilon decay, in [0, 1].
                end=0.05,
                # (int) Decay length (env step).
                decay=50000,
            ),
            replay_buffer=dict(
                replay_buffer_size=5000,
                # (int) The maximum reuse times of each data.
                max_reuse=1e+9,
                max_staleness=1e+9,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default model setting for demonstration.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
        """
        return 'madqn', ['ding.model.template.madqn']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Initialize learn mode. This creates:
              - one RMSprop optimizer for the ``current`` parameter group and one for ``cooperation``;
              - a momentum-updated target copy of the model;
              - hidden-state wrappers for both the learn model and the target model.
        """
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in MADQN"
        self._optimizer_current = RMSprop(
            params=self._model.current.parameters(),
            lr=self._cfg.learn.learning_rate,
            alpha=0.99,
            eps=0.00001,
            weight_decay=self._cfg.learn.weight_decay
        )
        self._optimizer_cooperation = RMSprop(
            params=self._model.cooperation.parameters(),
            lr=self._cfg.learn.learning_rate,
            alpha=0.99,
            eps=0.00001,
            weight_decay=self._cfg.learn.weight_decay
        )
        self._gamma = self._cfg.learn.discount_factor
        self._nstep = self._cfg.nstep
        self._target_model = copy.deepcopy(self._model)
        # The target model tracks the learn model with a soft (momentum) update controlled by target_update_theta.
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_update_theta}
        )
        # Both models keep one hidden state per batch entry; each state is a per-agent list initialised to None.
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model = model_wrap(
            self._model,
            wrapper_name='hidden_state',
            state_num=self._cfg.learn.batch_size,
            init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
        )
        self._learn_model.reset()
        self._target_model.reset()

    def _data_preprocess_learn(self, data: List[Any]) -> dict:
        r"""
        Overview:
            Preprocess the data to fit the required data format for learning
        Arguments:
            - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
        Returns:
            - data (:obj:`Dict[str, Any]`): the processed data, from \
                [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
        """
        # data preprocess
        data = timestep_collate(data)
        if self._cuda:
            data = to_device(data, self._device)
        data['weight'] = data.get('weight', None)
        data['done'] = data['done'].float()
        return data

    def _compute_td_loss(self, q: torch.Tensor, target_q: torch.Tensor, data: dict) -> Tuple[Any, Any]:
        r"""
        Overview:
            Compute the TD loss between ``q`` and the target built from ``target_q``. The current and the
            cooperation passes of ``_forward_learn`` share this helper.
        Arguments:
            - q (:obj:`torch.Tensor`): Q values of the learn model, indexed by timestep on dim 0.
            - target_q (:obj:`torch.Tensor`): Q values of the target model on ``next_obs``.
            - data (:obj:`dict`): Preprocessed batch. For ``nstep > 1``, ``data['reward']`` must already have been \
                permuted by ``_forward_learn``.
        Returns:
            - loss (:obj:`torch.Tensor`): Scalar TD loss. For n-step, it is averaged over ``unroll_len`` timesteps.
            - td_error_per_sample (:obj:`torch.Tensor`): Per-sample TD error, averaged the same way.
        """
        if self._nstep == 1:
            v_data = v_1step_td_data(q, target_q, data['reward'], data['done'], data['weight'])
            return v_1step_td_error(v_data, self._gamma)
        losses, td_errors = [], []
        for t in range(self._cfg.collect.unroll_len):
            # value_gamma=None lets the n-step return discount the bootstrapped value by gamma ** nstep.
            # Passing self._gamma here would discount it by a single gamma only.
            v_data = v_nstep_td_data(q[t], target_q[t], data['reward'][t], data['done'][t], data['weight'], None)
            loss_t, td_error_t = v_nstep_td_error(v_data, self._gamma, self._nstep)
            losses.append(loss_t)
            td_errors.append(td_error_t)
        loss = sum(losses) / (len(losses) + 1e-8)
        td_error_per_sample = sum(td_errors) / (len(td_errors) + 1e-8)
        return loss, td_error_per_sample

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode. The ``current`` parameter group is updated first, and
            the ``cooperation`` group second. The target model is then soft-updated from the learn model.
        Arguments:
            - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
                np.ndarray or dict/list combinations.
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
                recorded in text log and tensorboard, values are python scalar or a list of scalars.
        ArgumentsKeys:
            - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
        ReturnsKeys:
            - necessary: ``cur_lr``, ``total_loss``, ``total_q``, ``target_total_q``, ``grad_norm``, \
                ``cooperation_grad_norm``, ``cooperation_loss``
            - cur_lr (:obj:`float`): Current learning rate
            - total_loss (:obj:`float`): The calculated loss of the ``current`` parameter group
        """
        data = self._data_preprocess_learn(data)
        if self._nstep != 1:
            # Swap the last two reward axes so that data['reward'][t] puts the n-step axis first, as the n-step TD
            # helper expects (assumes reward is collated as (T, B, nstep) -- TODO confirm).
            data['reward'] = data['reward'].permute(0, 2, 1).contiguous()

        # ====================
        # current forward / update
        # ====================
        self._learn_model.train()
        self._target_model.train()
        # for hidden_state plugin, we need to reset the main model and target model
        self._learn_model.reset(state=data['prev_state'][0])
        self._target_model.reset(state=data['prev_state'][0])
        inputs = {'obs': data['obs'], 'action': data['action']}
        total_q = self._learn_model.forward(inputs, single_step=False)['total_q']

        if self._cfg.learn.double_q:
            # Double DQN: the learn model picks the next actions and the target model evaluates them.
            # The logits are detached anyway, so no autograd graph is needed here.
            self._learn_model.reset(state=data['prev_state'][1])
            with torch.no_grad():
                next_logit = self._learn_model.forward({'obs': data['next_obs']}, single_step=False)['logit']
            next_inputs = {'obs': data['next_obs'], 'action': next_logit.clone().detach().argmax(dim=-1)}
        else:
            next_inputs = {'obs': data['next_obs']}
        with torch.no_grad():
            target_total_q = self._target_model.forward(next_inputs, cooperation=True, single_step=False)['total_q']

        loss, _ = self._compute_td_loss(total_q, target_total_q, data)
        self._optimizer_current.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self._model.current.parameters(), self._cfg.learn.clip_value)
        self._optimizer_current.step()

        # ====================
        # cooperation forward / update
        # ====================
        self._learn_model.reset(state=data['prev_state'][0])
        self._target_model.reset(state=data['prev_state'][0])
        cooperation_total_q = self._learn_model.forward(inputs, cooperation=True, single_step=False)['total_q']
        next_inputs = {'obs': data['next_obs']}
        with torch.no_grad():
            cooperation_target_total_q = self._target_model.forward(
                next_inputs, cooperation=True, single_step=False
            )['total_q']
        cooperation_loss, _ = self._compute_td_loss(cooperation_total_q, cooperation_target_total_q, data)
        self._optimizer_cooperation.zero_grad()
        cooperation_loss.backward()
        cooperation_grad_norm = torch.nn.utils.clip_grad_norm_(
            self._model.cooperation.parameters(), self._cfg.learn.clip_value
        )
        self._optimizer_cooperation.step()

        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer_current.defaults['lr'],
            'total_loss': loss.item(),
            'total_q': total_q.mean().item() / self._cfg.model.agent_num,
            'target_total_q': target_total_q.mean().item() / self._cfg.model.agent_num,
            'grad_norm': grad_norm,
            'cooperation_grad_norm': cooperation_grad_norm,
            'cooperation_loss': cooperation_loss.item(),
        }

    def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
        r"""
        Overview:
            Reset learn model to the state indicated by data_id
        Arguments:
            - data_id (:obj:`Optional[List[int]]`): The id that store the state and we will reset\
                the model state to the state indicated by data_id
        """
        self._learn_model.reset(data_id=data_id)

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state_dict of learn mode: learn/target model weights and both optimizers.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
        """
        return {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_current': self._optimizer_current.state_dict(),
            'optimizer_cooperation': self._optimizer_cooperation.state_dict(),
        }

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load the state_dict variable into policy learn mode.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.

        .. tip::
            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
            complicated operation.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_current.load_state_dict(state_dict['optimizer_current'])
        self._optimizer_cooperation.load_state_dict(state_dict['optimizer_cooperation'])

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data, including 'obs', 'next_obs', 'prev_state',\
                'action', 'reward', 'done'
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'prev_state': model_output['prev_state'],
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Get the train sample from trajectory. When ``nstep > 1``, the trajectory is first converted into
            n-step return data.
        Arguments:
            - data (:obj:`list`): The trajectory's cache
        Returns:
            - samples (:obj:`dict`): The training samples generated
        """
        # Read nstep/gamma from cfg rather than attributes set in _init_learn, so this also works when only
        # collect mode has been initialised (e.g. a collector-only process).
        nstep = self._cfg.nstep
        if nstep != 1:
            data = get_nstep_return_data(data, nstep, gamma=self._cfg.learn.discount_factor)
        return get_train_sample(data, self._unroll_len)

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables to monitor. They match the keys returned by ``_forward_learn``.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return [
            'cur_lr', 'total_loss', 'total_q', 'target_total_q', 'grad_norm', 'cooperation_grad_norm',
            'cooperation_loss'
        ]
|