File size: 16,184 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 |
from typing import List, Dict, Any, Tuple, Union, Callable, Optional
from collections import namedtuple
from easydict import EasyDict
import copy
import random
import numpy as np
import torch
import treetensor.torch as ttorch
from torch.optim import AdamW
from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog
from ding.utils import POLICY_REGISTRY, RunningMeanStd
@POLICY_REGISTRY.register('ppof')
class PPOFPolicy:
    """
    PPO policy registered as ``'ppof'`` that bundles learning, collecting and evaluation in one object.

    The same instance is exposed as ``learn_mode``, ``collect_mode`` and ``eval_mode``. Accepted action spaces
    are ``'discrete'``, ``'continuous'``, ``'hybrid'`` and ``'multi_discrete'`` (the last one passes the
    constructor check, but no collect/eval sampler or PPO loss is set up for it in this class).
    """
    config = dict(
        type='ppo',
        on_policy=True,
        cuda=True,
        action_space='discrete',
        discount_factor=0.99,
        gae_lambda=0.95,
        # learn
        epoch_per_collect=10,
        batch_size=64,
        learning_rate=3e-4,
        # Linear learning-rate decay given as (epoch_num, min_lr_lambda), e.g. (10000, 0.1):
        # the lr factor decays linearly from 1.0 to min_lr_lambda over epoch_num scheduler steps.
        # None disables the scheduler.
        lr_scheduler=None,
        weight_decay=0,
        value_weight=0.5,
        entropy_weight=0.01,
        clip_ratio=0.2,
        adv_norm=True,
        # one of 'popart', 'value_rescale', 'symlog', 'baseline' (checked in forward)
        value_norm='baseline',
        ppo_param_init=True,
        grad_norm=0.5,
        # collect
        n_sample=128,
        unroll_len=1,
        # eval
        deterministic_eval=True,
        # model
        model=dict(),
    )
    mode = ['learn', 'collect', 'eval']

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Return a deep copy of ``config`` as an ``EasyDict``, tagged with ``cfg_type = '<ClassName>Dict'``.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @classmethod
    def default_model(cls: type) -> Callable:
        """
        Return the default model class (``PPOFModel``), imported lazily from the sibling ``model`` module.
        """
        from .model import PPOFModel
        return PPOFModel

    def __init__(
            self,
            cfg: "EasyDict",
            model: Optional[torch.nn.Module],
            enable_mode: Optional[List[str]] = None
    ) -> None:
        """
        Set up the model, optimizer/scheduler and samplers for the enabled modes.

        Arguments:
            - cfg: Policy config, normally derived from ``default_config()``.
            - model: The actor-critic network; ``None`` falls back to ``default_model()`` (see note below).
            - enable_mode: Subset of ``['learn', 'collect', 'eval']`` to initialize; defaults to all of them.
        """
        self._cfg = cfg
        if model is None:
            # NOTE(review): default_model() returns the model *class*, not an instance, so this branch stores
            # a class and later calls such as `.cuda()` are likely to fail. The constructor arguments cannot be
            # derived from this file -- TODO confirm the intended construction.
            self._model = self.default_model()
        else:
            self._model = model
        if self._cfg.cuda and torch.cuda.is_available():
            self._device = 'cuda'
            self._model.cuda()
        else:
            self._device = 'cpu'
        assert self._cfg.action_space in ["continuous", "discrete", "hybrid", 'multi_discrete']
        self._action_space = self._cfg.action_space
        if self._cfg.ppo_param_init:
            self._model_param_init()

        if enable_mode is None:
            enable_mode = self.mode
        self.enable_mode = enable_mode
        if 'learn' in enable_mode:
            self._optimizer = AdamW(
                self._model.parameters(),
                lr=self._cfg.learning_rate,
                weight_decay=self._cfg.weight_decay,
            )
            # define linear lr scheduler
            if self._cfg.lr_scheduler is not None:
                epoch_num, min_lr_lambda = self._cfg.lr_scheduler

                self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                    self._optimizer,
                    lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
                )

            if self._cfg.value_norm:
                self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
        # NOTE(review): 'multi_discrete' passes the assertion above but gets no sampler below, so collect()/eval()
        # would raise AttributeError for it.
        if 'collect' in enable_mode:
            if self._action_space == 'discrete':
                self._collect_sampler = MultinomialSampler()
            elif self._action_space == 'continuous':
                self._collect_sampler = ReparameterizationSampler()
            elif self._action_space == 'hybrid':
                self._collect_sampler = HybridStochasticSampler()
        if 'eval' in enable_mode:
            if self._action_space == 'discrete':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = ArgmaxSampler()
                else:
                    self._eval_sampler = MultinomialSampler()
            elif self._action_space == 'continuous':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = MuSampler()
                else:
                    self._eval_sampler = ReparameterizationSampler()
            elif self._action_space == 'hybrid':
                if self._cfg.deterministic_eval:
                    self._eval_sampler = HybridDeterminsticSampler()
                else:
                    self._eval_sampler = HybridStochasticSampler()
        # for compatibility
        self.learn_mode = self
        self.collect_mode = self
        self.eval_mode = self

    def _model_param_init(self) -> None:
        """
        PPO-style parameter initialization.

        Every ``nn.Linear`` gets an orthogonal weight and a zero bias. For continuous/hybrid action spaces the
        critic and actor linears are re-initialized with gain sqrt(2), the Gaussian ``log_sigma_param`` is set
        to -0.5, and the ``mu`` head weights are scaled by 0.01 with zero bias.
        """
        for module in self._model.modules():
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.orthogonal_(module.weight)
                if module.bias is not None:  # Linear(bias=False) has no bias to reset
                    torch.nn.init.zeros_(module.bias)
        if self._action_space in ['continuous', 'hybrid']:
            for module in list(self._model.critic.modules()) + list(self._model.actor.modules()):
                if isinstance(module, torch.nn.Linear):
                    # orthogonal initialization
                    torch.nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)
            # init log sigma
            if self._action_space == 'continuous':
                torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
                for module in self._model.actor_head.mu.modules():
                    if isinstance(module, torch.nn.Linear):
                        if module.bias is not None:
                            torch.nn.init.zeros_(module.bias)
                        # shrink the last policy layer so initial action means start close to 0
                        module.weight.data.copy_(0.01 * module.weight.data)
            elif self._action_space == 'hybrid':  # actor_head[1]: ReparameterizationHead, for action_args
                if hasattr(self._model.actor_head[1], 'log_sigma_param'):
                    torch.nn.init.constant_(self._model.actor_head[1].log_sigma_param, -0.5)
                    for module in self._model.actor_head[1].mu.modules():
                        if isinstance(module, torch.nn.Linear):
                            if module.bias is not None:
                                torch.nn.init.zeros_(module.bias)
                            module.weight.data.copy_(0.01 * module.weight.data)

    def _recompute_adv_and_return(self, data: ttorch.Tensor) -> None:
        """
        Recompute ``data.adv`` (GAE), ``data.value`` and ``data.return_`` in place with the current critic.

        For 'value_rescale', 'symlog' and 'baseline' the critic predictions are mapped back to the raw scale
        before GAE, and the value/return targets are mapped back into the critic's output space afterwards.
        For 'popart' the reward is normalized with the PopArt statistics instead.
        """
        value_norm = self._cfg.value_norm
        with torch.no_grad():
            # The critic returns a dict; in popart it has two keys: 'pred' and 'unnormalized_pred'.
            value = self._model.compute_critic(data.obs)
            next_value = self._model.compute_critic(data.next_obs)
            reward = data.reward
            if value_norm == 'popart':
                popart = self._model.critic_head.popart
                reward = (reward - popart.mu) / popart.sigma
                value = value['pred']
                next_value = next_value['pred']
            elif value_norm == 'value_rescale':
                value = value_inv_transform(value['pred'])
                next_value = value_inv_transform(next_value['pred'])
            elif value_norm == 'symlog':
                value = inv_symlog(value['pred'])
                next_value = inv_symlog(next_value['pred'])
            elif value_norm == 'baseline':
                std = self._running_mean_std.std
                value = value['pred'] * std
                next_value = next_value['pred'] * std

            traj_flag = data.get('traj_flag', None)  # traj_flag indicates termination of trajectory
            adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
            data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)

            unnormalized_returns = value + data.adv  # In popart, this return is normalized
            returns = unnormalized_returns
            if value_norm == 'popart':
                self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
            elif value_norm == 'value_rescale':
                value = value_transform(value)
                returns = value_transform(unnormalized_returns)
            elif value_norm == 'symlog':
                value = symlog(value)
                returns = symlog(unnormalized_returns)
            elif value_norm == 'baseline':
                std = self._running_mean_std.std
                value = value / std
                returns = unnormalized_returns / std
                # Feed the running statistics with raw-scale returns; the normalized ones would make the
                # estimated std track its own output instead of the real return scale.
                self._running_mean_std.update(unnormalized_returns.cpu().numpy())
            data.value = value
            data.return_ = returns

    def _compute_ppo_loss(self, output: ttorch.Tensor, batch: ttorch.Tensor, adv: torch.Tensor) -> Tuple[Any, Any]:
        """
        Compute the PPO loss terms and diagnostics for one mini-batch.

        Arguments:
            - output: Actor-critic output of the current model for ``batch.obs``.
            - batch: Mini-batch of collected data with ``value`` and ``return_`` filled in.
            - adv: Advantages of the mini-batch (already normalized if ``cfg.adv_norm``).
        Returns:
            - (ppo_loss, ppo_info): loss record with ``policy_loss``/``value_loss``/``entropy_loss`` fields and
              info record with ``approx_kl``/``clipfrac`` fields.
        Raises:
            - NotImplementedError: For action spaces without a PPO loss here (e.g. ``'multi_discrete'``).
        """
        clip_ratio = self._cfg.clip_ratio
        if self._action_space == 'continuous':
            ppo_batch = ppo_data(
                output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
            )
            return ppo_error_continuous(ppo_batch, clip_ratio)
        if self._action_space == 'discrete':
            ppo_batch = ppo_data(
                output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
            )
            return ppo_error(ppo_batch, clip_ratio)
        if self._action_space == 'hybrid':
            # discrete part (discrete policy loss and entropy loss)
            ppo_discrete_batch = ppo_policy_data(
                output.logit.action_type, batch.logit.action_type, batch.action.action_type, adv, None
            )
            ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, clip_ratio)
            # continuous part (continuous policy loss and entropy loss, value loss)
            ppo_continuous_batch = ppo_data(
                output.logit.action_args, batch.logit.action_args, batch.action.action_args, output.value,
                batch.value, adv, batch.return_, None
            )
            ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(ppo_continuous_batch, clip_ratio)
            # sum discrete and continuous loss; report the worse (larger) of the two diagnostics
            ppo_loss = type(ppo_continuous_loss)(
                ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss,
                ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
            )
            ppo_info = type(ppo_continuous_info)(
                max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
                max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
            )
            return ppo_loss, ppo_info
        raise NotImplementedError('PPO loss is not implemented for action space: {}'.format(self._action_space))

    def forward(self, data: ttorch.Tensor) -> List[Dict[str, Any]]:
        """
        Run one round of PPO training on freshly collected data.

        Arguments:
            - data: Collected transitions with at least ``obs``, ``next_obs``, ``action``, ``logit``, ``reward``
              and ``done`` fields (``traj_flag`` is used if present).
        Returns:
            - return_infos: One dict of training statistics per mini-batch update.
        """
        assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'], (
            'Not supported value normalization! Value normalization supported: '
            'popart, value rescale, symlog, baseline'
        )
        return_infos = []
        self._model.train()
        bs = self._cfg.batch_size
        # keep a whole number of mini-batches (rounding based on cfg.n_sample)
        data = data[:self._cfg.n_sample // bs * bs]

        # outer training loop
        for _ in range(self._cfg.epoch_per_collect):
            # recompute adv and return targets with the current critic
            self._recompute_adv_and_return(data)

            # inner training loop: shuffle the mini-batch list itself so the iteration order really changes
            mini_batches = list(ttorch.split(data, bs))
            random.shuffle(mini_batches)
            for batch in mini_batches:
                output = self._model.compute_actor_critic(batch.obs)
                adv = batch.adv
                if self._cfg.adv_norm:
                    # Normalize advantage in a train_batch
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)

                ppo_loss, ppo_info = self._compute_ppo_loss(output, batch, adv)
                wv, we = self._cfg.value_weight, self._cfg.entropy_weight
                total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss

                self._optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
                self._optimizer.step()

                return_info = {
                    # param_groups carries the scheduler-adjusted lr; optimizer.defaults never changes
                    'cur_lr': self._optimizer.param_groups[0]['lr'],
                    'total_loss': total_loss.item(),
                    'policy_loss': ppo_loss.policy_loss.item(),
                    'value_loss': ppo_loss.value_loss.item(),
                    'entropy_loss': ppo_loss.entropy_loss.item(),
                    'adv_max': adv.max().item(),
                    'adv_mean': adv.mean().item(),
                    'value_mean': output.value.mean().item(),
                    'value_max': output.value.max().item(),
                    'approx_kl': ppo_info.approx_kl,
                    'clipfrac': ppo_info.clipfrac,
                }
                if self._action_space == 'continuous':
                    return_info.update(
                        {
                            'action': batch.action.float().mean().item(),
                            'mu_mean': output.logit.mu.mean().item(),
                            'sigma_mean': output.logit.sigma.mean().item(),
                        }
                    )
                elif self._action_space == 'hybrid':
                    return_info.update(
                        {
                            'action': batch.action.action_args.float().mean().item(),
                            'mu_mean': output.logit.action_args.mu.mean().item(),
                            'sigma_mean': output.logit.action_args.sigma.mean().item(),
                        }
                    )
                return_infos.append(return_info)
            if self._cfg.lr_scheduler is not None:
                # the scheduler is stepped once per epoch
                self._lr_scheduler.step()
        return return_infos

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the model state dict, plus the optimizer state dict when 'learn' mode is enabled.
        """
        state_dict = {
            'model': self._model.state_dict(),
        }
        if 'learn' in self.enable_mode:
            state_dict['optimizer'] = self._optimizer.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the model (and, in 'learn' mode, the optimizer) from a dict produced by ``state_dict()``.
        """
        self._model.load_state_dict(state_dict['model'])
        if 'learn' in self.enable_mode:
            self._optimizer.load_state_dict(state_dict['optimizer'])

    def collect(self, data: ttorch.Tensor) -> ttorch.Tensor:
        """
        Run the actor-critic in eval mode without gradients and attach an action sampled by the collect sampler.

        Returns:
            - output: The model's actor-critic output with an added ``action`` field.
        """
        self._model.eval()
        with torch.no_grad():
            output = self._model.compute_actor_critic(data)
            action = self._collect_sampler(output.logit)
            output.action = action
        return output

    def process_transition(self, obs: ttorch.Tensor, inference_output: dict, timestep: namedtuple) -> ttorch.Tensor:
        """
        Pack one environment step into the transition record consumed by ``forward``.

        ``inference_output`` is read by attribute (``action``, ``logit``, ``value``), e.g. the result of ``collect``.
        """
        return ttorch.as_tensor(
            {
                'obs': obs,
                'next_obs': timestep.obs,
                'action': inference_output.action,
                'logit': inference_output.logit,
                'value': inference_output.value,
                'reward': timestep.reward,
                'done': timestep.done,
            }
        )

    def eval(self, data: ttorch.Tensor) -> ttorch.Tensor:
        """
        Compute actor logits without gradients and pick actions with the eval sampler.

        The eval sampler is deterministic when ``cfg.deterministic_eval`` is set, stochastic otherwise.
        """
        self._model.eval()
        with torch.no_grad():
            logit = self._model.compute_actor(data)
            action = self._eval_sampler(logit)
        return ttorch.as_tensor({'logit': logit, 'action': action})

    def monitor_vars(self) -> List[str]:
        """
        Names of training statistics to monitor; each one is a key of the info dicts returned by ``forward``.
        """
        variables = [
            'cur_lr',
            'policy_loss',
            'value_loss',
            'entropy_loss',
            'adv_max',
            'adv_mean',
            'approx_kl',
            'clipfrac',
            'value_max',
            'value_mean',
        ]
        if self._action_space in ['continuous', 'hybrid']:
            # only these action spaces report action / mu / sigma statistics in forward()
            variables += ['mu_mean', 'sigma_mean', 'action']
        return variables

    def reset(self, env_id_list: Optional[List[int]] = None) -> None:
        """
        No per-environment state to reset for this policy.
        """
        pass
|