|
""" |
|
Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc. |
|
MCMC is adapted from https://github.com/google-research/ibc. |
|
""" |
|
from typing import Callable, Tuple |
|
from functools import wraps |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY |
|
from ding.torch_utils import unsqueeze_repeat |
|
from ding.model.wrapper import IModelWrapper |
|
from ding.model.common import RegressionHead |
|
|
|
|
|
def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict): |
|
""" |
|
Overview: |
|
        Build a stochastic optimizer from ``STOCHASTIC_OPTIMIZER_REGISTRY`` according to the given config. \
            The ``type`` key selects the registered optimizer class and is popped before the remaining \
            keys are passed to its constructor as keyword arguments.
|
Arguments: |
|
- device (:obj:`str`): Device. |
|
        - stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config, which must contain a \
            ``type`` key naming a registered optimizer ('dfo', 'ardfo' or 'mcmc').
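    Examples:
        >>> # Illustrative config; ``type`` must name a registered optimizer ('dfo', 'ardfo' or 'mcmc').
        >>> config = dict(type='dfo', noise_scale=0.33, train_samples=8)
        >>> opt = create_stochastic_optimizer('cpu', config)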
|
""" |
|
return STOCHASTIC_OPTIMIZER_REGISTRY.build( |
|
stochastic_optimizer_config.pop("type"), device=device, **stochastic_optimizer_config |
|
) |
|
|
|
|
|
def no_ebm_grad(): |
|
"""Wrapper that disables energy based model gradients""" |
|
|
|
def ebm_disable_grad_wrapper(func: Callable): |
|
|
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
ebm = args[-1] |
|
assert isinstance(ebm, (IModelWrapper, nn.Module)),\ |
|
                'Make sure ebm is the last positional argument.'
|
ebm.requires_grad_(False) |
|
result = func(*args, **kwargs) |
|
ebm.requires_grad_(True) |
|
return result |
|
|
|
return wrapper |
|
|
|
return ebm_disable_grad_wrapper |
|
|
|
|
|
class StochasticOptimizer(ABC): |
|
""" |
|
Overview: |
|
Base class for stochastic optimizers. |
|
Interface: |
|
        ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer``
|
""" |
|
|
|
def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Overview: |
|
            Draw action samples uniformly at random within the action bounds \
|
            and tile each observation so that it is paired with every one of its action samples.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observation. |
|
- num_samples (:obj:`int`): The number of negative samples. |
|
Returns: |
|
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
|
            - action (:obj:`torch.Tensor`): Action samples drawn uniformly within the action bounds.
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- num_samples (:obj:`int`): :math:`N`. |
|
- tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. |
|
- action (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
Examples: |
|
>>> obs = torch.randn(2, 4) |
|
            >>> opt = DFO()  # StochasticOptimizer is abstract; use any concrete subclass.
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
>>> tiled_obs, action = opt._sample(obs, 8) |
|
""" |
|
size = (obs.shape[0], num_samples, self.action_bounds.shape[1]) |
|
low, high = self.action_bounds[0, :], self.action_bounds[1, :] |
|
action_samples = low + (high - low) * torch.rand(size).to(self.device) |
|
tiled_obs = unsqueeze_repeat(obs, num_samples, 1) |
|
return tiled_obs, action_samples |
|
|
|
@staticmethod |
|
@torch.no_grad() |
|
def _get_best_action_sample(obs: torch.Tensor, action_samples: torch.Tensor, ebm: nn.Module): |
|
""" |
|
Overview: |
|
            Return, for each observation in the batch, the action sample with the highest probability (lowest energy).
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observation. |
|
            - action_samples (:obj:`torch.Tensor`): Candidate action samples.
|
Returns: |
|
- best_action_samples (:obj:`torch.Tensor`): Best action. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. |
|
Examples: |
|
>>> obs = torch.randn(2, 4) |
|
>>> action_samples = torch.randn(2, 8, 5) |
|
>>> ebm = EBM(4, 5) |
|
            >>> opt = DFO()  # StochasticOptimizer is abstract; use any concrete subclass.
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
>>> best_action_samples = opt._get_best_action_sample(obs, action_samples, ebm) |
|
""" |
|
|
|
energies = ebm.forward(obs, action_samples) |
|
probs = F.softmax(-1.0 * energies, dim=-1) |
|
|
|
best_idxs = probs.argmax(dim=-1) |
|
return action_samples[torch.arange(action_samples.size(0)), best_idxs] |
|
|
|
def set_action_bounds(self, action_bounds: np.ndarray): |
|
""" |
|
Overview: |
|
Set action bounds calculated from the dataset statistics. |
|
Arguments: |
|
- action_bounds (:obj:`np.ndarray`): Array of shape (2, A), \ |
|
where action_bounds[0] is lower bound and action_bounds[1] is upper bound. |
|
Returns: |
|
            - None. The bounds are converted to a :obj:`torch.Tensor` and stored as ``self.action_bounds``.
|
Shapes: |
|
- action_bounds (:obj:`np.ndarray`): :math:`(2, A)`. |
|
            - self.action_bounds (:obj:`torch.Tensor`): :math:`(2, A)`.
|
Examples: |
|
            >>> opt = DFO()  # StochasticOptimizer is abstract; use any concrete subclass.
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
""" |
|
self.action_bounds = torch.as_tensor(action_bounds, dtype=torch.float32).to(self.device) |
|
|
|
@abstractmethod |
|
def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Overview: |
|
            Create tiled observations and sample negative actions (counter-examples) for the InfoNCE loss.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- tiled_obs (:obj:`torch.Tensor`): Tiled observations. |
|
- action (:obj:`torch.Tensor`): Actions. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. |
|
- action (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
|
|
.. note:: In the case of derivative-free optimization, this function will simply call _sample. |
|
""" |
|
raise NotImplementedError |
|
|
|
@abstractmethod |
|
def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Optimize for the best action conditioned on the current observation. |
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- best_action_samples (:obj:`torch.Tensor`): Best actions. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. |
|
""" |
|
raise NotImplementedError |
|
|
|
|
|
@STOCHASTIC_OPTIMIZER_REGISTRY.register('dfo') |
|
class DFO(StochasticOptimizer): |
|
""" |
|
Overview: |
|
        Derivative-Free Optimizer (DFO) from the paper Implicit Behavioral Cloning.
|
https://arxiv.org/abs/2109.00137 |
|
Interface: |
|
        ``__init__``, ``sample``, ``infer``
|
""" |
|
|
|
def __init__( |
|
self, |
|
noise_scale: float = 0.33, |
|
noise_shrink: float = 0.5, |
|
iters: int = 3, |
|
train_samples: int = 8, |
|
inference_samples: int = 16384, |
|
device: str = 'cpu', |
|
): |
|
""" |
|
Overview: |
|
            Initialize the Derivative-Free Optimizer (DFO).
|
Arguments: |
|
- noise_scale (:obj:`float`): Initial noise scale. |
|
- noise_shrink (:obj:`float`): Noise scale shrink rate. |
|
- iters (:obj:`int`): Number of iterations. |
|
- train_samples (:obj:`int`): Number of samples for training. |
|
- inference_samples (:obj:`int`): Number of samples for inference. |
|
- device (:obj:`str`): Device. |
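        Examples:
            >>> # All arguments are optional; unspecified ones keep their defaults.
            >>> opt = DFO(noise_scale=0.33, inference_samples=16384, device='cpu')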
|
""" |
|
self.action_bounds = None |
|
self.noise_scale = noise_scale |
|
self.noise_shrink = noise_shrink |
|
self.iters = iters |
|
self.train_samples = train_samples |
|
self.inference_samples = inference_samples |
|
self.device = device |
|
|
|
def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Overview: |
|
            Draw action samples uniformly at random within the action bounds \
|
            and tile each observation so that it is paired with every one of its action samples.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- tiled_obs (:obj:`torch.Tensor`): Tiled observation. |
|
- action_samples (:obj:`torch.Tensor`): Action samples. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. |
|
- action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
Examples: |
|
>>> obs = torch.randn(2, 4) |
|
>>> ebm = EBM(4, 5) |
|
>>> opt = DFO() |
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
>>> tiled_obs, action_samples = opt.sample(obs, ebm) |
|
""" |
|
return self._sample(obs, self.train_samples) |
|
|
|
@torch.no_grad() |
|
def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Optimize for the best action conditioned on the current observation. |
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- best_action_samples (:obj:`torch.Tensor`): Actions. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. |
|
Examples: |
|
>>> obs = torch.randn(2, 4) |
|
>>> ebm = EBM(4, 5) |
|
>>> opt = DFO() |
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
>>> best_action_samples = opt.infer(obs, ebm) |
|
""" |
|
noise_scale = self.noise_scale |
|
|
|
|
|
obs, action_samples = self._sample(obs, self.inference_samples) |
|
|
|
for i in range(self.iters): |
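            # Compute energies of the current candidate actions and turn them into sampling probabilities.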
|
|
|
energies = ebm.forward(obs, action_samples) |
|
probs = F.softmax(-1.0 * energies, dim=-1) |
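            # Resample actions with replacement according to their probabilities (lower energy, higher probability).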
|
|
|
|
|
idxs = torch.multinomial(probs, self.inference_samples, replacement=True) |
|
action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs] |
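            # Perturb the resampled actions with Gaussian noise and clip them back into the action bounds.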
|
|
|
|
|
action_samples = action_samples + torch.randn_like(action_samples) * noise_scale |
|
action_samples = action_samples.clamp(min=self.action_bounds[0, :], max=self.action_bounds[1, :]) |
|
|
|
noise_scale *= self.noise_shrink |
|
|
|
|
|
return self._get_best_action_sample(obs, action_samples, ebm) |
|
|
|
|
|
@STOCHASTIC_OPTIMIZER_REGISTRY.register('ardfo') |
|
class AutoRegressiveDFO(DFO): |
|
""" |
|
Overview: |
|
AutoRegressive Derivative-Free Optimizer in paper Implicit Behavioral Cloning. |
|
https://arxiv.org/abs/2109.00137 |
|
Interface: |
|
``__init__``, ``infer`` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
noise_scale: float = 0.33, |
|
noise_shrink: float = 0.5, |
|
iters: int = 3, |
|
train_samples: int = 8, |
|
inference_samples: int = 4096, |
|
device: str = 'cpu', |
|
): |
|
""" |
|
Overview: |
|
            Initialize the AutoRegressive Derivative-Free Optimizer.
|
Arguments: |
|
- noise_scale (:obj:`float`): Initial noise scale. |
|
- noise_shrink (:obj:`float`): Noise scale shrink rate. |
|
- iters (:obj:`int`): Number of iterations. |
|
- train_samples (:obj:`int`): Number of samples for training. |
|
- inference_samples (:obj:`int`): Number of samples for inference. |
|
- device (:obj:`str`): Device. |
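        Examples:
            >>> # Same arguments as DFO; only the default of ``inference_samples`` differs.
            >>> opt = AutoRegressiveDFO(noise_scale=0.33, inference_samples=4096)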
|
""" |
|
super().__init__(noise_scale, noise_shrink, iters, train_samples, inference_samples, device) |
|
|
|
@torch.no_grad() |
|
def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Optimize for the best action conditioned on the current observation. |
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- best_action_samples (:obj:`torch.Tensor`): Actions. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. |
|
Examples: |
|
>>> obs = torch.randn(2, 4) |
|
>>> ebm = EBM(4, 5) |
|
>>> opt = AutoRegressiveDFO() |
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
>>> best_action_samples = opt.infer(obs, ebm) |
|
""" |
|
noise_scale = self.noise_scale |
|
|
|
|
|
obs, action_samples = self._sample(obs, self.inference_samples) |
|
|
|
for i in range(self.iters): |
|
|
|
for j in range(action_samples.shape[-1]): |
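                # Energies of the j-th autoregressive head, which conditions on action dimensions up to j.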
|
|
|
energies = ebm.forward(obs, action_samples)[..., j] |
|
probs = F.softmax(-1.0 * energies, dim=-1) |
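                # Resample full action vectors according to the j-th head's probabilities.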
|
|
|
|
|
idxs = torch.multinomial(probs, self.inference_samples, replacement=True) |
|
action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs] |
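                # Add noise to the j-th action dimension only and clip it to its bounds.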
|
|
|
|
|
action_samples[..., j] = action_samples[..., j] + torch.randn_like(action_samples[..., j]) * noise_scale |
|
|
|
action_samples[..., j] = action_samples[..., j].clamp( |
|
min=self.action_bounds[0, j], max=self.action_bounds[1, j] |
|
) |
|
|
|
noise_scale *= self.noise_shrink |
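        # Select the best action according to the last head, which sees the complete action vector.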
|
|
|
|
|
energies = ebm.forward(obs, action_samples)[..., -1] |
|
probs = F.softmax(-1.0 * energies, dim=-1) |
|
|
|
best_idxs = probs.argmax(dim=-1) |
|
return action_samples[torch.arange(action_samples.size(0)), best_idxs] |
|
|
|
|
|
@STOCHASTIC_OPTIMIZER_REGISTRY.register('mcmc') |
|
class MCMC(StochasticOptimizer): |
|
""" |
|
Overview: |
|
        MCMC-based stochastic optimizer from the paper Implicit Behavioral Cloning.
|
https://arxiv.org/abs/2109.00137 |
|
Interface: |
|
``__init__``, ``sample``, ``infer``, ``grad_penalty`` |
|
""" |
|
|
|
class BaseScheduler(ABC): |
|
""" |
|
Overview: |
|
Base class for learning rate scheduler. |
|
Interface: |
|
``get_rate`` |
|
""" |
|
|
|
@abstractmethod |
|
def get_rate(self, index): |
|
""" |
|
Overview: |
|
Abstract method for getting learning rate. |
|
""" |
|
raise NotImplementedError |
|
|
|
    class ExponentialScheduler(BaseScheduler):
|
""" |
|
Overview: |
|
Exponential learning rate schedule for Langevin sampler. |
|
Interface: |
|
``__init__``, ``get_rate`` |
|
""" |
|
|
|
def __init__(self, init, decay): |
|
""" |
|
Overview: |
|
Initialize the ExponentialScheduler. |
|
Arguments: |
|
- init (:obj:`float`): Initial learning rate. |
|
- decay (:obj:`float`): Decay rate. |
|
""" |
|
self._decay = decay |
|
self._latest_lr = init |
|
|
|
def get_rate(self, index): |
|
""" |
|
Overview: |
|
                Get the current learning rate and decay it for the next call. \
                    Assumes sequential calls; ``index`` is ignored.
|
Arguments: |
|
- index (:obj:`int`): Current iteration. |
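            Examples:
                >>> scheduler = MCMC.ExponentialScheduler(init=1e-1, decay=0.99)
                >>> lr = scheduler.get_rate(0)  # 0.1; each later call returns the previous rate times ``decay``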
|
""" |
|
del index |
|
lr = self._latest_lr |
|
self._latest_lr *= self._decay |
|
return lr |
|
|
|
    class PolynomialScheduler(BaseScheduler):
|
""" |
|
Overview: |
|
Polynomial learning rate schedule for Langevin sampler. |
|
Interface: |
|
``__init__``, ``get_rate`` |
|
""" |
|
|
|
def __init__(self, init, final, power, num_steps): |
|
""" |
|
Overview: |
|
Initialize the PolynomialScheduler. |
|
Arguments: |
|
- init (:obj:`float`): Initial learning rate. |
|
- final (:obj:`float`): Final learning rate. |
|
- power (:obj:`float`): Power of polynomial. |
|
- num_steps (:obj:`int`): Number of steps. |
|
""" |
|
self._init = init |
|
self._final = final |
|
self._power = power |
|
self._num_steps = num_steps |
|
|
|
def get_rate(self, index): |
|
""" |
|
Overview: |
|
                Get the learning rate for the given iteration index; ``index = -1`` returns the initial rate.
|
Arguments: |
|
- index (:obj:`int`): Current iteration. |
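            Examples:
                >>> scheduler = MCMC.PolynomialScheduler(init=0.5, final=1e-5, power=2.0, num_steps=100)
                >>> stepsize = scheduler.get_rate(-1)  # returns ``init``; later indices decay towards ``final``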
|
""" |
|
if index == -1: |
|
return self._init |
|
return ( |
|
(self._init - self._final) * ((1 - (float(index) / float(self._num_steps - 1))) ** (self._power)) |
|
) + self._final |
|
|
|
def __init__( |
|
self, |
|
iters: int = 100, |
|
use_langevin_negative_samples: bool = True, |
|
train_samples: int = 8, |
|
inference_samples: int = 512, |
|
stepsize_scheduler: dict = dict( |
|
init=0.5, |
|
final=1e-5, |
|
power=2.0, |
|
|
|
), |
|
optimize_again: bool = True, |
|
again_stepsize_scheduler: dict = dict( |
|
init=1e-5, |
|
final=1e-5, |
|
power=2.0, |
|
|
|
), |
|
device: str = 'cpu', |
|
|
|
noise_scale: float = 0.5, |
|
grad_clip=None, |
|
delta_action_clip: float = 0.5, |
|
add_grad_penalty: bool = True, |
|
grad_norm_type: str = 'inf', |
|
grad_margin: float = 1.0, |
|
grad_loss_weight: float = 1.0, |
|
**kwargs, |
|
): |
|
""" |
|
Overview: |
|
Initialize the MCMC. |
|
Arguments: |
|
- iters (:obj:`int`): Number of iterations. |
|
- use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler. |
|
- train_samples (:obj:`int`): Number of samples for training. |
|
- inference_samples (:obj:`int`): Number of samples for inference. |
|
- stepsize_scheduler (:obj:`dict`): Step size scheduler for Langevin sampler. |
|
- optimize_again (:obj:`bool`): Whether to run a second optimization. |
|
- again_stepsize_scheduler (:obj:`dict`): Step size scheduler for the second optimization. |
|
- device (:obj:`str`): Device. |
|
- noise_scale (:obj:`float`): Initial noise scale. |
|
            - grad_clip (:obj:`float`): Gradient clipping threshold; ``None`` disables clipping.
|
            - delta_action_clip (:obj:`float`): Clip range of the per-step action change, \
                as a fraction of half the action range.
|
- add_grad_penalty (:obj:`bool`): Whether to add gradient penalty. |
|
- grad_norm_type (:obj:`str`): Gradient norm type. |
|
- grad_margin (:obj:`float`): Gradient margin. |
|
- grad_loss_weight (:obj:`float`): Gradient loss weight. |
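        Examples:
            >>> # ``num_steps`` of both step-size schedulers is filled in later with ``iters``.
            >>> opt = MCMC(iters=100, inference_samples=512, device='cpu')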
|
""" |
|
self.iters = iters |
|
self.use_langevin_negative_samples = use_langevin_negative_samples |
|
self.train_samples = train_samples |
|
self.inference_samples = inference_samples |
|
self.stepsize_scheduler = stepsize_scheduler |
|
self.optimize_again = optimize_again |
|
self.again_stepsize_scheduler = again_stepsize_scheduler |
|
self.device = device |
|
|
|
self.noise_scale = noise_scale |
|
self.grad_clip = grad_clip |
|
self.delta_action_clip = delta_action_clip |
|
self.add_grad_penalty = add_grad_penalty |
|
self.grad_norm_type = grad_norm_type |
|
self.grad_margin = grad_margin |
|
self.grad_loss_weight = grad_loss_weight |
|
|
|
@staticmethod |
|
def _gradient_wrt_act( |
|
obs: torch.Tensor, |
|
action: torch.Tensor, |
|
ebm: nn.Module, |
|
create_graph: bool = False, |
|
) -> torch.Tensor: |
|
""" |
|
Overview: |
|
            Calculate the gradient of the energy w.r.t. the action.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- action (:obj:`torch.Tensor`): Actions. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
- create_graph (:obj:`bool`): Whether to create graph. |
|
Returns: |
|
- grad (:obj:`torch.Tensor`): Gradient w.r.t action. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. |
|
- action (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- grad (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
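        Examples:
            >>> obs = torch.randn(2, 8, 4)
            >>> action = torch.randn(2, 8, 5)
            >>> ebm = EBM(4, 5)
            >>> grad = MCMC._gradient_wrt_act(obs, action, ebm)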
|
""" |
|
action.requires_grad_(True) |
|
energy = ebm.forward(obs, action).sum() |
|
|
|
|
|
grad = torch.autograd.grad(energy, action, create_graph=create_graph)[0] |
|
action.requires_grad_(False) |
|
return grad |
|
|
|
def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Calculate gradient penalty. |
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- action (:obj:`torch.Tensor`): Actions. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- loss (:obj:`torch.Tensor`): Gradient penalty. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`. |
|
- action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N+1, O)`. |
|
- loss (:obj:`torch.Tensor`): :math:`(B, )`. |
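        Examples:
            >>> obs = torch.randn(2, 9, 4)
            >>> action = torch.randn(2, 9, 5)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> loss = opt.grad_penalty(obs, action, ebm)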
|
""" |
|
if not self.add_grad_penalty: |
|
return 0. |
|
|
|
de_dact = MCMC._gradient_wrt_act(obs, action, ebm, create_graph=True) |
|
|
|
def compute_grad_norm(grad_norm_type, de_dact) -> torch.Tensor: |
|
|
|
|
|
grad_norm_type_to_ord = { |
|
'1': 1, |
|
'2': 2, |
|
'inf': float('inf'), |
|
} |
|
ord = grad_norm_type_to_ord[grad_norm_type] |
|
return torch.linalg.norm(de_dact, ord, dim=-1) |
|
|
|
|
|
grad_norms = compute_grad_norm(self.grad_norm_type, de_dact) |
|
grad_norms = grad_norms - self.grad_margin |
|
grad_norms = grad_norms.clamp(min=0., max=1e10) |
|
grad_norms = grad_norms.pow(2) |
|
|
|
grad_loss = grad_norms.mean() |
|
return grad_loss * self.grad_loss_weight |
|
|
|
|
|
|
|
@no_ebm_grad() |
|
def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor: |
|
""" |
|
Overview: |
|
            Run one Langevin MCMC step.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- action (:obj:`torch.Tensor`): Actions. |
|
- stepsize (:obj:`float`): Step size. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- action (:obj:`torch.Tensor`): Actions. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. |
|
- action (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
            - stepsize (:obj:`float`): Scalar.
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
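        Examples:
            >>> obs = torch.randn(2, 8, 4)
            >>> action = torch.rand(2, 8, 5)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> action = opt._langevin_step(obs, action, 0.1, ebm)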
|
""" |
|
l_lambda = 1.0 |
|
de_dact = MCMC._gradient_wrt_act(obs, action, ebm) |
|
|
|
if self.grad_clip: |
|
de_dact = de_dact.clamp(min=-self.grad_clip, max=self.grad_clip) |
|
|
|
gradient_scale = 0.5 |
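        # Langevin proposal: half the energy gradient plus Gaussian exploration noise.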
|
de_dact = (gradient_scale * l_lambda * de_dact + torch.randn_like(de_dact) * l_lambda * self.noise_scale) |
|
|
|
delta_action = stepsize * de_dact |
|
delta_action_clip = self.delta_action_clip * 0.5 * (self.action_bounds[1] - self.action_bounds[0]) |
|
delta_action = delta_action.clamp(min=-delta_action_clip, max=delta_action_clip) |
|
|
|
action = action - delta_action |
|
action = action.clamp(min=self.action_bounds[0], max=self.action_bounds[1]) |
|
|
|
return action |
|
|
|
@no_ebm_grad() |
|
def _langevin_action_given_obs( |
|
self, |
|
obs: torch.Tensor, |
|
action: torch.Tensor, |
|
ebm: nn.Module, |
|
scheduler: BaseScheduler = None |
|
) -> torch.Tensor: |
|
""" |
|
Overview: |
|
            Run Langevin MCMC for ``self.iters`` steps.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- action (:obj:`torch.Tensor`): Actions. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
- scheduler (:obj:`BaseScheduler`): Learning rate scheduler. |
|
Returns: |
|
- action (:obj:`torch.Tensor`): Actions. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. |
|
- action (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
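        Examples:
            >>> obs = torch.randn(2, 8, 4)
            >>> action = torch.rand(2, 8, 5)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC(iters=3)
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> action = opt._langevin_action_given_obs(obs, action, ebm)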
|
""" |
|
if not scheduler: |
|
self.stepsize_scheduler['num_steps'] = self.iters |
|
scheduler = MCMC.PolynomialScheduler(**self.stepsize_scheduler) |
|
stepsize = scheduler.get_rate(-1) |
|
for i in range(self.iters): |
|
action = self._langevin_step(obs, action, stepsize, ebm) |
|
stepsize = scheduler.get_rate(i) |
|
return action |
|
|
|
@no_ebm_grad() |
|
def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Overview: |
|
            Create tiled observations and sample negative actions (counter-examples) for the InfoNCE loss.
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- tiled_obs (:obj:`torch.Tensor`): Tiled observations. |
|
- action_samples (:obj:`torch.Tensor`): Action samples. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`. |
|
- action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`. |
|
Examples: |
|
>>> obs = torch.randn(2, 4) |
|
>>> ebm = EBM(4, 5) |
|
>>> opt = MCMC() |
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
>>> tiled_obs, action_samples = opt.sample(obs, ebm) |
|
""" |
|
obs, uniform_action_samples = self._sample(obs, self.train_samples) |
|
if not self.use_langevin_negative_samples: |
|
return obs, uniform_action_samples |
|
langevin_action_samples = self._langevin_action_given_obs(obs, uniform_action_samples, ebm) |
|
return obs, langevin_action_samples |
|
|
|
@no_ebm_grad() |
|
def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Optimize for the best action conditioned on the current observation. |
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observations. |
|
- ebm (:obj:`torch.nn.Module`): Energy based model. |
|
Returns: |
|
- best_action_samples (:obj:`torch.Tensor`): Actions. |
|
Shapes: |
|
- obs (:obj:`torch.Tensor`): :math:`(B, O)`. |
|
- ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`. |
|
- best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`. |
|
Examples: |
|
>>> obs = torch.randn(2, 4) |
|
>>> ebm = EBM(4, 5) |
|
>>> opt = MCMC() |
|
>>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0)) |
|
>>> best_action_samples = opt.infer(obs, ebm) |
|
""" |
|
|
|
obs, uniform_action_samples = self._sample(obs, self.inference_samples) |
|
action_samples = self._langevin_action_given_obs( |
|
obs, |
|
uniform_action_samples, |
|
ebm, |
|
) |
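        # Optionally run a second Langevin refinement with the ``again_stepsize_scheduler`` step sizes.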
|
|
|
|
|
if self.optimize_again: |
|
self.again_stepsize_scheduler['num_steps'] = self.iters |
|
action_samples = self._langevin_action_given_obs( |
|
obs, |
|
action_samples, |
|
ebm, |
|
scheduler=MCMC.PolynomialScheduler(**self.again_stepsize_scheduler), |
|
) |
|
|
|
|
|
return self._get_best_action_sample(obs, action_samples, ebm) |
|
|
|
|
|
@MODEL_REGISTRY.register('ebm') |
|
class EBM(nn.Module): |
|
""" |
|
Overview: |
|
Energy based model. |
|
Interface: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
obs_shape: int, |
|
action_shape: int, |
|
hidden_size: int = 512, |
|
hidden_layer_num: int = 4, |
|
**kwargs, |
|
): |
|
""" |
|
Overview: |
|
Initialize the EBM. |
|
Arguments: |
|
- obs_shape (:obj:`int`): Observation shape. |
|
- action_shape (:obj:`int`): Action shape. |
|
- hidden_size (:obj:`int`): Hidden size. |
|
- hidden_layer_num (:obj:`int`): Number of hidden layers. |
|
""" |
|
super().__init__() |
|
input_size = obs_shape + action_shape |
|
self.net = nn.Sequential( |
|
nn.Linear(input_size, hidden_size), nn.ReLU(), |
|
RegressionHead( |
|
hidden_size, |
|
1, |
|
hidden_layer_num, |
|
final_tanh=False, |
|
) |
|
) |
|
|
|
def forward(self, obs, action): |
|
""" |
|
Overview: |
|
Forward computation graph of EBM. |
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observation of shape (B, N, O). |
|
- action (:obj:`torch.Tensor`): Action of shape (B, N, A). |
|
Returns: |
|
- pred (:obj:`torch.Tensor`): Energy of shape (B, N). |
|
Examples: |
|
>>> obs = torch.randn(2, 3, 4) |
|
>>> action = torch.randn(2, 3, 5) |
|
>>> ebm = EBM(4, 5) |
|
>>> pred = ebm(obs, action) |
|
""" |
|
x = torch.cat([obs, action], -1) |
|
x = self.net(x) |
|
return x['pred'] |
|
|
|
|
|
@MODEL_REGISTRY.register('arebm') |
|
class AutoregressiveEBM(nn.Module): |
|
""" |
|
Overview: |
|
Autoregressive energy based model. |
|
Interface: |
|
``__init__``, ``forward`` |
|
""" |
|
|
|
def __init__( |
|
self, |
|
obs_shape: int, |
|
action_shape: int, |
|
hidden_size: int = 512, |
|
hidden_layer_num: int = 4, |
|
): |
|
""" |
|
Overview: |
|
Initialize the AutoregressiveEBM. |
|
Arguments: |
|
- obs_shape (:obj:`int`): Observation shape. |
|
- action_shape (:obj:`int`): Action shape. |
|
- hidden_size (:obj:`int`): Hidden size. |
|
- hidden_layer_num (:obj:`int`): Number of hidden layers. |
|
""" |
|
super().__init__() |
|
self.ebm_list = nn.ModuleList() |
|
for i in range(action_shape): |
|
self.ebm_list.append(EBM(obs_shape, i + 1, hidden_size, hidden_layer_num)) |
|
|
|
def forward(self, obs, action): |
|
""" |
|
Overview: |
|
Forward computation graph of AutoregressiveEBM. |
|
Arguments: |
|
- obs (:obj:`torch.Tensor`): Observation of shape (B, N, O). |
|
- action (:obj:`torch.Tensor`): Action of shape (B, N, A). |
|
Returns: |
|
- pred (:obj:`torch.Tensor`): Energy of shape (B, N, A). |
|
Examples: |
|
>>> obs = torch.randn(2, 3, 4) |
|
>>> action = torch.randn(2, 3, 5) |
|
>>> arebm = AutoregressiveEBM(4, 5) |
|
>>> pred = arebm(obs, action) |
|
""" |
|
output_list = [] |
|
for i, ebm in enumerate(self.ebm_list): |
|
output_list.append(ebm(obs, action[..., :i + 1])) |
|
        return torch.stack(output_list, dim=-1)
|
|