# File size: 12,295 Bytes (revision 079c32c)
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, get_nstep_return_data
from ding.model import model_wrap
from ding.policy import Policy
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .common_utils import default_preprocess_learn
@POLICY_REGISTRY.register('bcq')
class BCQPolicy(Policy):
    r"""
    Overview:
        Policy class registered under the name ``'bcq'``. One learn iteration performs three
        successive updates on the wrapped model:

            1. the VAE (``compute_vae``), trained with reconstruction + KL loss;
            2. the twin critics (``compute_critic``), trained with a 1-step TD loss against a
               target built from VAE-sampled, actor-perturbed candidate actions;
            3. the actor (``compute_actor``), trained to maximise the first critic's Q-value.

        ``_forward_collect`` and ``_process_transition`` are no-op stubs in this class.
    """

    config = dict(
        type='bcq',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool) Whether to use priority when sampling from the buffer.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        random_collect_size=10000,
        # (int) N-step length handed to ``get_nstep_return_data`` in ``_get_train_sample``.
        nstep=1,
        model=dict(
            # (List) Hidden list for actor network head.
            actor_head_hidden_size=[400, 300],
            # (List) Hidden list for critic network head.
            critic_head_hidden_size=[400, 300],
            # (float) Max perturbation hyper-parameter for BCQ.
            phi=0.05,
        ),
        learn=dict(
            # How many updates (iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy -> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=100,
            # (float) Learning rate of the optimizer over ``model.critic`` parameters.
            learning_rate_q=3e-4,
            # (float) Learning rate of the optimizer over ``model.actor`` parameters.
            learning_rate_policy=3e-4,
            # (float) Learning rate of the optimizer over ``model.vae`` parameters.
            learning_rate_vae=3e-4,
            # (bool) Whether to ignore done (usually for max-step-termination envs, e.g. pendulum).
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
            # When an episode only ends because this step limit is reached, done==True is replaced
            # with done==False so that the TD target (``gamma * (1 - done) * next_v + reward``)
            # stays accurate.
            ignore_done=False,
            # (float) Used for soft update of the target network,
            # aka. interpolation factor in polyak averaging for target networks.
            target_theta=0.005,
            # (float) Discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (float) Weight put on the minimum of the twin target Q-values when building the
            # TD target; ``1 - lmbda`` is put on the maximum.
            lmbda=0.75,
            # (float) Weight uniform initialization range in the last output layer.
            init_w=3e-3,
        ),
        collect=dict(
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                # (int) Max size of replay buffer.
                replay_buffer_size=1000000,
                # (int) max_use: Max use times of one data in the buffer.
                # Data will be removed once used for too many times.
                # Default to infinite.
                # max_use=256,
            ),
        ),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the default model name and the import path(s) where it is defined.
        Returns:
            - model_info (:obj:`Tuple[str, List[str]]`): ``('bcq', ['ding.model.template.bcq'])``.
        """
        return 'bcq', ['ding.model.template.bcq']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the critic, actor and VAE optimizers, algorithm config, main and target models.
        """
        # Init
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self.lmbda = self._cfg.learn.lmbda
        # Size of the noise vector fed to ``vae.decode_with_obs``.
        # NOTE(review): assumes ``model.action_shape`` is an int - confirm for tuple-shaped actions.
        self.latent_dim = self._cfg.model.action_shape * 2

        # Optimizers: one per sub-network so that each loss only steps its own parameters.
        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )
        self._optimizer_vae = Adam(
            self._model.vae.parameters(),
            lr=self._cfg.learn.learning_rate_vae,
        )

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Run one training iteration: update the VAE, then the twin critics, then the actor,
            and finally soft-update the target model.
        Arguments:
            - data (:obj:`dict`): Training batch; after ``default_preprocess_learn`` it is read for \
                the keys ``obs``, ``next_obs``, ``action``, ``reward``, ``done`` and ``weight``.
        Returns:
            - info (:obj:`Dict[str, Any]`): ``td_error`` and ``target_q_value`` as python floats, plus \
                the loss entries ``recons_loss``, ``kld_loss``, ``vae_loss``, ``critic_loss``, \
                ``twin_critic_loss`` and ``actor_loss``.
        """
        # Number of candidate actions sampled per next observation for the TD target.
        num_candidates = 10
        # Latent noise is clipped to [-latent_clip, latent_clip] before decoding.
        latent_clip = 0.5

        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        # Make sure the action always has an explicit trailing dimension: (B, ) -> (B, 1).
        if len(data.get('action').shape) == 1:
            data['action'] = data['action'].reshape(-1, 1)

        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data['obs']
        next_obs = data['next_obs']
        reward = data['reward']
        done = data['done']
        batch_size = obs.shape[0]

        # ====================
        # train_vae
        # ====================
        vae_out = self._model.forward(data, mode='compute_vae')
        recon, mean, log_var = vae_out['recons_action'], vae_out['mu'], vae_out['log_var']
        recons_loss = F.mse_loss(recon, data['action'])
        # KL term summed over dim 1 and averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1), dim=0)
        loss_dict['recons_loss'] = recons_loss
        loss_dict['kld_loss'] = kld_loss
        vae_loss = recons_loss + 0.5 * kld_loss
        loss_dict['vae_loss'] = vae_loss
        self._optimizer_vae.zero_grad()
        vae_loss.backward()
        self._optimizer_vae.step()

        # ====================
        # train_critic
        # ====================
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        with torch.no_grad():
            # Each next observation is repeated so that several candidate actions are scored for it.
            next_obs_rep = torch.repeat_interleave(next_obs, num_candidates, 0)
            z = torch.randn((next_obs_rep.shape[0], self.latent_dim)).to(self._device)
            z = z.clamp(-latent_clip, latent_clip)
            vae_action = self._model.vae.decode_with_obs(z, next_obs_rep)['reconstruction_action']
            next_action = self._target_model.forward(
                {
                    'obs': next_obs_rep,
                    'action': vae_action
                }, mode='compute_actor'
            )['action']

            next_data = {'obs': next_obs_rep, 'action': next_action}
            target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
            # Blend the twin target critics: ``lmbda`` weights their element-wise minimum and
            # ``1 - lmbda`` their element-wise maximum.
            target_q_value = self.lmbda * torch.min(target_q_value[0], target_q_value[1]) \
                + (1 - self.lmbda) * torch.max(target_q_value[0], target_q_value[1])
            # Group the candidates of each sample back together and keep the best one -> (B, 1).
            target_q_value = target_q_value.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

        q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
        loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
        q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
        loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
        td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
        self._optimizer_q.zero_grad()
        (loss_dict['critic_loss'] + loss_dict['twin_critic_loss']).backward()
        self._optimizer_q.step()

        # ====================
        # train_policy
        # ====================
        z = torch.randn((obs.shape[0], self.latent_dim)).to(self._device)
        z = z.clamp(-latent_clip, latent_clip)
        sample_action = self._model.vae.decode_with_obs(z, obs)['reconstruction_action']
        actor_input = {'obs': obs, 'action': sample_action}
        perturbed_action = self._model.forward(actor_input, mode='compute_actor')['action']
        q_input = {'obs': obs, 'action': perturbed_action}
        # Only the first critic is used for the actor objective.
        q = self._learn_model.forward(q_input, mode='compute_critic')['q_value'][0]
        loss_dict['actor_loss'] = -q.mean()
        # Only the actor optimizer steps here; gradients that reach other sub-networks are
        # cleared by their own ``zero_grad`` before their next update.
        self._optimizer_policy.zero_grad()
        loss_dict['actor_loss'].backward()
        self._optimizer_policy.step()

        # ====================
        # after update
        # ====================
        self._forward_learn_cnt += 1
        self._target_model.update(self._learn_model.state_dict())
        return {
            'td_error': td_error_per_sample.detach().mean().item(),
            'target_q_value': target_q_value.detach().mean().item(),
            **loss_dict
        }

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables reported by ``_forward_learn``.
        Returns:
            - vars (:obj:`List[str]`): Keys of the dict returned by ``_forward_learn``.
        """
        return [
            'td_error', 'target_q_value', 'critic_loss', 'twin_critic_loss', 'actor_loss', 'recons_loss', 'kld_loss',
            'vae_loss'
        ]

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the state dict of learn mode: learn model, target model and the three optimizers.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): Dict with keys ``model``, ``target_model``, \
                ``optimizer_q``, ``optimizer_policy`` and ``optimizer_vae``.
        """
        ret = {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_q': self._optimizer_q.state_dict(),
            'optimizer_policy': self._optimizer_policy.state_dict(),
            'optimizer_vae': self._optimizer_vae.state_dict(),
        }
        return ret

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Restore every entry written by ``_state_dict_learn``.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): Dict with keys ``model``, ``target_model``, \
                ``optimizer_q``, ``optimizer_policy`` and ``optimizer_vae``.
        """
        # NOTE(review): this policy keeps three separate optimizers, so a loader written for a
        # single ``optimizer`` entry cannot restore it - confirm against the base ``Policy`` class.
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict['target_model'])
        self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
        self._optimizer_policy.load_state_dict(state_dict['optimizer_policy'])
        self._optimizer_vae.load_state_dict(state_dict['optimizer_vae'])

    def _init_eval(self):
        r"""
        Overview:
            Evaluate mode init method. Wrap the model with the ``base`` wrapper and reset it.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward function of eval mode: collate the observations, run ``compute_eval``
            without gradients and split the output back per entry.
        Arguments:
            - data (:obj:`dict`): Mapping from an id to the observation of that id.
        Returns:
            - output (:obj:`Dict[str, Any]`): Mapping from each input id to its model output.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        data = {'obs': data}
        self._eval_model.eval()
        with torch.no_grad():
            output = self._eval_model.forward(data, mode='compute_eval')
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Store the values needed by ``_get_train_sample`` and wrap
            the model with the ``eps_greedy_sample`` wrapper.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        # The default config only defines ``learn.discount_factor``; a top-level
        # ``discount_factor`` is still honoured when a user config provides one.
        gamma = getattr(self._cfg, 'discount_factor', None)
        if gamma is None:
            gamma = self._cfg.learn.discount_factor
        self._gamma = gamma  # necessary for parallel
        self._nstep = self._cfg.nstep  # necessary for parallel
        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, **kwargs) -> dict:
        r"""
        Overview:
            No-op stub: this policy does not implement a collect forward pass and returns ``None``.
        """
        pass

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            No-op stub: this policy does not build transitions and returns ``None``.
        """
        pass

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
        r"""
        Overview:
            Compute the n-step return data of a trajectory, then cut it into training samples
            of length ``unroll_len``.
        Arguments:
            - data (:obj:`list`): The trajectory's cache.
        Returns:
            - samples (:obj:`Union[None, List[Any]]`): The result of ``get_train_sample``.
        """
        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
        return get_train_sample(data, self._unroll_len)
|