File size: 6,988 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from typing import Dict, Any, List, Tuple
from collections import namedtuple
from easydict import EasyDict
import torch
import torch.nn.functional as F
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils.data import default_collate, default_decollate
from ding.utils import POLICY_REGISTRY
from .bc import BehaviourCloningPolicy
from ding.model.template.ebm import create_stochastic_optimizer
from ding.model.template.ebm import StochasticOptimizer, MCMC, AutoRegressiveDFO
from ding.torch_utils import unsqueeze_repeat
from ding.utils import EasyTimer
@POLICY_REGISTRY.register('ibc')
class IBCPolicy(BehaviourCloningPolicy):
r"""
Overview:
Implicit Behavior Cloning
https://arxiv.org/abs/2109.00137.pdf
.. note::
The code is adapted from the pytorch version of IBC https://github.com/kevinzakka/ibc,
which only supports the derivative-free optimization (dfo) variants.
This implementation moves a step forward and supports all variants of energy-based model
mentioned in the paper (dfo, autoregressive dfo, and mcmc).
"""
config = dict(
type='ibc',
cuda=False,
on_policy=False,
continuous=True,
model=dict(stochastic_optim=dict(type='mcmc', )),
learn=dict(
train_epoch=30,
batch_size=256,
optim=dict(
learning_rate=1e-5,
weight_decay=0.0,
beta1=0.9,
beta2=0.999,
),
),
eval=dict(evaluator=dict(eval_freq=10000, )),
)
def default_model(self) -> Tuple[str, List[str]]:
return 'ebm', ['ding.model.template.ebm']
def _init_learn(self):
self._timer = EasyTimer(cuda=self._cfg.cuda)
self._sync_timer = EasyTimer(cuda=self._cfg.cuda)
optim_cfg = self._cfg.learn.optim
self._optimizer = torch.optim.AdamW(
self._model.parameters(),
lr=optim_cfg.learning_rate,
weight_decay=optim_cfg.weight_decay,
betas=(optim_cfg.beta1, optim_cfg.beta2),
)
self._stochastic_optimizer: StochasticOptimizer = \
create_stochastic_optimizer(self._device, self._cfg.model.stochastic_optim)
self._learn_model = model_wrap(self._model, 'base')
self._learn_model.reset()
def _forward_learn(self, data):
with self._timer:
data = default_collate(data)
if self._cuda:
data = to_device(data, self._device)
self._learn_model.train()
loss_dict = dict()
# obs: (B, O)
# action: (B, A)
obs, action = data['obs'], data['action']
# When action/observation space is 1, the action/observation dimension will
# be squeezed in the first place, therefore unsqueeze there to make the data
# compatiable with the ibc pipeline.
if len(obs.shape) == 1:
obs = obs.unsqueeze(-1)
if len(action.shape) == 1:
action = action.unsqueeze(-1)
# N refers to the number of negative samples, i.e. self._stochastic_optimizer.inference_samples.
# (B, N, O), (B, N, A)
obs, negatives = self._stochastic_optimizer.sample(obs, self._learn_model)
# (B, N+1, A)
targets = torch.cat([action.unsqueeze(dim=1), negatives], dim=1)
# (B, N+1, O)
obs = torch.cat([obs[:, :1], obs], dim=1)
permutation = torch.rand(targets.shape[0], targets.shape[1]).argsort(dim=1)
targets = targets[torch.arange(targets.shape[0]).unsqueeze(-1), permutation]
# (B, )
ground_truth = (permutation == 0).nonzero()[:, 1].to(self._device)
# (B, N+1) for ebm
# (B, N+1, A) for autoregressive ebm
energy = self._learn_model.forward(obs, targets)
logits = -1.0 * energy
if isinstance(self._stochastic_optimizer, AutoRegressiveDFO):
# autoregressive case
# (B, A)
ground_truth = unsqueeze_repeat(ground_truth, logits.shape[-1], -1)
loss = F.cross_entropy(logits, ground_truth)
loss_dict['ebm_loss'] = loss.item()
if isinstance(self._stochastic_optimizer, MCMC):
grad_penalty = self._stochastic_optimizer.grad_penalty(obs, targets, self._learn_model)
loss += grad_penalty
loss_dict['grad_penalty'] = grad_penalty.item()
loss_dict['total_loss'] = loss.item()
self._optimizer.zero_grad()
loss.backward()
with self._sync_timer:
if self._cfg.multi_gpu:
self.sync_gradients(self._learn_model)
sync_time = self._sync_timer.value
self._optimizer.step()
total_time = self._timer.value
return {
'total_time': total_time,
'sync_time': sync_time,
**loss_dict,
}
def _monitor_vars_learn(self):
if isinstance(self._stochastic_optimizer, MCMC):
return ['total_loss', 'ebm_loss', 'grad_penalty', 'total_time', 'sync_time']
else:
return ['total_loss', 'ebm_loss', 'total_time', 'sync_time']
def _init_eval(self):
self._eval_model = model_wrap(self._model, wrapper_name='base')
self._eval_model.reset()
def _forward_eval(self, data: dict) -> dict:
tensor_input = isinstance(data, torch.Tensor)
if not tensor_input:
data_id = list(data.keys())
data = default_collate(list(data.values()))
if self._cuda:
data = to_device(data, self._device)
self._eval_model.eval()
output = self._stochastic_optimizer.infer(data, self._eval_model)
output = dict(action=output)
if self._cuda:
output = to_device(output, 'cpu')
if tensor_input:
return output
else:
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}
def set_statistic(self, statistics: EasyDict) -> None:
self._stochastic_optimizer.set_action_bounds(statistics.action_bounds)
# =================================================================== #
# Implicit Behavioral Cloning does not need `collect`-related functions
# =================================================================== #
def _init_collect(self):
raise NotImplementedError
def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
raise NotImplementedError
def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
raise NotImplementedError
def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
raise NotImplementedError
|