# zjowowen's picture
# init space
# 079c32c
"""
Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc.
MCMC is adapted from https://github.com/google-research/ibc.
"""
from typing import Callable, Tuple
from functools import wraps
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY
from ding.torch_utils import unsqueeze_repeat
from ding.model.wrapper import IModelWrapper
from ding.model.common import RegressionHead
def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict):
    """
    Overview:
        Build a stochastic optimizer through ``STOCHASTIC_OPTIMIZER_REGISTRY``. The ``type`` entry of the \
        config selects the registered class; every other entry is forwarded as a keyword argument.
    Arguments:
        - device (:obj:`str`): Device, forwarded to the optimizer as the ``device`` keyword argument.
        - stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config. Must contain the key \
            ``type``. The mapping passed in is left unmodified.
    Returns:
        - optimizer: The object returned by ``STOCHASTIC_OPTIMIZER_REGISTRY.build``.
    Raises:
        - KeyError: If ``stochastic_optimizer_config`` has no ``type`` key.
    """
    # Pop from a shallow copy: removing "type" from the caller's dict would make a second call with the
    # same config fail with KeyError.
    config = dict(stochastic_optimizer_config)
    optimizer_type = config.pop("type")
    return STOCHASTIC_OPTIMIZER_REGISTRY.build(optimizer_type, device=device, **config)
def no_ebm_grad():
    """
    Overview:
        Decorator factory that disables gradients of the energy based model parameters while the \
        decorated function runs, then restores each parameter's previous ``requires_grad`` flag.
    Returns:
        - decorator (:obj:`Callable`): Decorator to apply to a function that receives the energy based \
            model either as its last positional argument or as the keyword argument ``ebm``.

    .. note::
        The previous flags are restored (instead of being forced to ``True``), so decorated functions can \
        call each other without the inner call re-enabling gradients for the rest of the outer call, and \
        parameters that were frozen beforehand stay frozen. Restoration also happens when the decorated \
        function raises.
    """

    def ebm_disable_grad_wrapper(func: Callable):

        @wraps(func)
        def wrapper(*args, **kwargs):
            # The model is looked up by keyword first, otherwise it must be the last positional argument.
            ebm = kwargs['ebm'] if 'ebm' in kwargs else args[-1]
            assert isinstance(ebm, nn.Module) or isinstance(ebm, IModelWrapper),\
                'Make sure ebm is the last positional arguments.'
            # Snapshot the per-parameter flags before freezing.
            # NOTE(review): relies on a wrapped model exposing ``parameters()`` the same way it exposes
            # ``requires_grad_()`` -- confirm for IModelWrapper.
            params = list(ebm.parameters())
            previous_flags = [param.requires_grad for param in params]
            ebm.requires_grad_(False)
            try:
                return func(*args, **kwargs)
            finally:
                for param, flag in zip(params, previous_flags):
                    param.requires_grad_(flag)

        return wrapper

    return ebm_disable_grad_wrapper
class StochasticOptimizer(ABC):
    """
    Overview:
        Abstract base class of the stochastic optimizers that propose and select actions for an energy \
        based model. This class defines no ``__init__``: its helpers read ``self.device`` and \
        ``self.action_bounds``, so a subclass has to provide ``device`` and ``set_action_bounds`` has to be \
        called before sampling.
    Interface:
        ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer``
    """

    def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Draw ``num_samples`` actions per observation uniformly inside the action bounds and repeat \
            every observation ``num_samples`` times along a new dimension 1.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation.
            - num_samples (:obj:`int`): Number of actions drawn for each observation.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Result of ``unsqueeze_repeat(obs, num_samples, 1)``.
            - action_samples (:obj:`torch.Tensor`): Uniformly sampled actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action = opt._sample(obs, 8)
        """
        bounds = self.action_bounds
        batch_size, action_dim = obs.shape[0], bounds.shape[1]
        lower, upper = bounds[0, :], bounds[1, :]
        # Uniform noise in [0, 1) is drawn first and then moved to the optimizer device.
        unit_noise = torch.rand((batch_size, num_samples, action_dim)).to(self.device)
        action_samples = lower + (upper - lower) * unit_noise
        tiled_obs = unsqueeze_repeat(obs, num_samples, 1)
        return tiled_obs, action_samples

    @staticmethod
    @torch.no_grad()
    def _get_best_action_sample(obs: torch.Tensor, action_samples: torch.Tensor, ebm: nn.Module):
        """
        Overview:
            For every batch entry, pick the candidate action whose softmax probability over negated \
            energies is the largest. Runs without gradient tracking.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations, passed to ``ebm.forward`` together with the actions.
            - action_samples (:obj:`torch.Tensor`): Candidate actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model; ``ebm.forward(obs, action_samples)`` is \
                expected to give one energy per candidate.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Selected action of every batch entry.
        Shapes:
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 8, 4)
            >>> action_samples = torch.randn(2, 8, 5)
            >>> ebm = EBM(4, 5)
            >>> best_action_samples = StochasticOptimizer._get_best_action_sample(obs, action_samples, ebm)
        """
        # One energy per candidate, (B, N).
        energies = ebm.forward(obs, action_samples)
        likelihood = F.softmax(-1.0 * energies, dim=-1)
        # Index of the most likely candidate in every batch row, (B, ).
        winner = likelihood.argmax(dim=-1)
        batch_index = torch.arange(action_samples.size(0))
        return action_samples[batch_index, winner]

    def set_action_bounds(self, action_bounds: np.ndarray):
        """
        Overview:
            Store the action bounds as a ``float32`` tensor on ``self.device`` in ``self.action_bounds``. \
            Nothing is returned.
        Arguments:
            - action_bounds (:obj:`np.ndarray`): Array whose row 0 is the lower bound and row 1 the upper \
                bound of every action dimension.
        Shapes:
            - action_bounds (:obj:`np.ndarray`): :math:`(2, A)`.
        Examples:
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
        """
        bounds = torch.as_tensor(action_bounds, dtype=torch.float32)
        self.action_bounds = bounds.to(self.device)

    @abstractmethod
    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-examples (negatives) for the InfoNCE loss. \
            Must be implemented by subclasses.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.

        .. note:: In the case of derivative-free optimization, this function will simply call _sample.
        """
        raise NotImplementedError

    @abstractmethod
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation. Must be implemented by \
            subclasses.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Best actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        """
        raise NotImplementedError
@STOCHASTIC_OPTIMIZER_REGISTRY.register('dfo')
class DFO(StochasticOptimizer):
    """
    Overview:
        Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``sample``, ``infer``
    """

    def __init__(
            self,
            noise_scale: float = 0.33,
            noise_shrink: float = 0.5,
            iters: int = 3,
            train_samples: int = 8,
            inference_samples: int = 16384,
            device: str = 'cpu',
    ):
        """
        Overview:
            Store the hyper-parameters of the Derivative-Free Optimizer. The action bounds start as \
            ``None`` and have to be provided through ``set_action_bounds`` before sampling.
        Arguments:
            - noise_scale (:obj:`float`): Standard deviation of the noise added in the first iteration.
            - noise_shrink (:obj:`float`): Factor the noise scale is multiplied by after each iteration.
            - iters (:obj:`int`): Number of resample-and-perturb iterations in ``infer``.
            - train_samples (:obj:`int`): Number of actions sampled per observation in ``sample``.
            - inference_samples (:obj:`int`): Number of actions sampled per observation in ``infer``.
            - device (:obj:`str`): Device.
        """
        self.device = device
        self.action_bounds = None
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        self.iters = iters
        self.noise_scale = noise_scale
        self.noise_shrink = noise_shrink

    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Draw ``train_samples`` uniformly random actions per observation and tile the observations \
            accordingly. The energy based model is not used.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model (unused, kept for the common interface).
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observation.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action_samples = opt.sample(obs, ebm)
        """
        return self._sample(obs, self.train_samples)

    @torch.no_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Search for the best action of every observation: start from uniform candidates, then \
            repeatedly resample them in proportion to ``softmax(-energy)``, perturb them with shrinking \
            Gaussian noise and clamp them to the action bounds. Finally return the most likely candidate.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        sigma = self.noise_scale

        # tiled_obs: (B, N, O), candidates: (B, N, A)
        tiled_obs, candidates = self._sample(obs, self.inference_samples)
        lower, upper = self.action_bounds[0, :], self.action_bounds[1, :]

        for _ in range(self.iters):
            # Energies and the matching resampling weights, both (B, N).
            energies = ebm.forward(tiled_obs, candidates)
            weights = F.softmax(-1.0 * energies, dim=-1)

            # Draw N candidate indices per batch row, with replacement.
            picks = torch.multinomial(weights, self.inference_samples, replacement=True)
            rows = torch.arange(candidates.size(0)).unsqueeze(-1)
            candidates = candidates[rows, picks]

            # Perturb the survivors and keep them inside the action bounds.
            perturbed = candidates + torch.randn_like(candidates) * sigma
            candidates = perturbed.clamp(min=lower, max=upper)

            sigma *= self.noise_shrink

        # Keep the candidate with the highest probability for every batch row.
        return self._get_best_action_sample(tiled_obs, candidates, ebm)
@STOCHASTIC_OPTIMIZER_REGISTRY.register('ardfo')
class AutoRegressiveDFO(DFO):
    """
    Overview:
        AutoRegressive Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``infer``
    """

    def __init__(
            self,
            noise_scale: float = 0.33,
            noise_shrink: float = 0.5,
            iters: int = 3,
            train_samples: int = 8,
            inference_samples: int = 4096,
            device: str = 'cpu',
    ):
        """
        Overview:
            Initialize the AutoRegressive Derivative-Free Optimizer. All arguments are handed to \
            ``DFO.__init__``; only the default of ``inference_samples`` differs.
        Arguments:
            - noise_scale (:obj:`float`): Initial noise scale.
            - noise_shrink (:obj:`float`): Noise scale shrink rate.
            - iters (:obj:`int`): Number of iterations.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - device (:obj:`str`): Device.
        """
        super().__init__(
            noise_scale=noise_scale,
            noise_shrink=noise_shrink,
            iters=iters,
            train_samples=train_samples,
            inference_samples=inference_samples,
            device=device,
        )

    @torch.no_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Search for the best action one action dimension at a time. In every iteration and for every \
            dimension ``j``, candidates are resampled according to the ``j``-th energy output of ``ebm``, \
            then only dimension ``j`` is perturbed and clamped. The final choice uses the last energy output.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model whose output has one energy per action \
                dimension in its last axis.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = AutoregressiveEBM(4, 5)
            >>> opt = AutoRegressiveDFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        sigma = self.noise_scale

        # tiled_obs: (B, N, O), candidates: (B, N, A)
        tiled_obs, candidates = self._sample(obs, self.inference_samples)
        action_dim = candidates.shape[-1]

        for _ in range(self.iters):
            for dim_idx in range(action_dim):
                # Energy of the dim_idx-th autoregressive head, (B, N).
                energies = ebm.forward(tiled_obs, candidates)[..., dim_idx]
                weights = F.softmax(-1.0 * energies, dim=-1)

                # Draw N candidate indices per batch row, with replacement.
                picks = torch.multinomial(weights, self.inference_samples, replacement=True)
                rows = torch.arange(candidates.size(0)).unsqueeze(-1)
                candidates = candidates[rows, picks]

                # Only the current dimension is perturbed and clamped to its own bounds.
                column = candidates[..., dim_idx]
                perturbed = column + torch.randn_like(column) * sigma
                candidates[..., dim_idx] = perturbed.clamp(
                    min=self.action_bounds[0, dim_idx], max=self.action_bounds[1, dim_idx]
                )

            # The noise shrinks once per outer iteration.
            sigma *= self.noise_shrink

        # Rank the final candidates with the last energy output, (B, N).
        final_energies = ebm.forward(tiled_obs, candidates)[..., -1]
        likelihood = F.softmax(-1.0 * final_energies, dim=-1)
        # (B, )
        winner = likelihood.argmax(dim=-1)
        return candidates[torch.arange(candidates.size(0)), winner]
@STOCHASTIC_OPTIMIZER_REGISTRY.register('mcmc')
class MCMC(StochasticOptimizer):
    """
    Overview:
        MCMC method as stochastic optimizers in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``sample``, ``infer``, ``grad_penalty``
    """

    class BaseScheduler(ABC):
        """
        Overview:
            Base class for learning rate scheduler.
        Interface:
            ``get_rate``
        """

        @abstractmethod
        def get_rate(self, index):
            """
            Overview:
                Abstract method for getting learning rate.
            Arguments:
                - index (:obj:`int`): Current iteration.
            """
            raise NotImplementedError

    class ExponentialScheduler(BaseScheduler):
        """
        Overview:
            Exponential learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, decay):
            """
            Overview:
                Initialize the ExponentialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - decay (:obj:`float`): Decay rate.
            """
            self._decay = decay
            self._latest_lr = init

        def get_rate(self, index):
            """
            Overview:
                Return the current learning rate and multiply the stored rate by the decay. \
                Assumes calling sequentially; ``index`` is ignored.
            Arguments:
                - index (:obj:`int`): Current iteration (unused).
            Returns:
                - lr (:obj:`float`): Learning rate before this call's decay is applied.
            """
            del index
            lr = self._latest_lr
            self._latest_lr *= self._decay
            return lr

    class PolynomialScheduler(BaseScheduler):
        """
        Overview:
            Polynomial learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, final, power, num_steps):
            """
            Overview:
                Initialize the PolynomialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - final (:obj:`float`): Final learning rate.
                - power (:obj:`float`): Power of polynomial.
                - num_steps (:obj:`int`): Number of steps.
            """
            self._init = init
            self._final = final
            self._power = power
            self._num_steps = num_steps

        def get_rate(self, index):
            """
            Overview:
                Get learning rate for index. ``index == -1`` returns the initial rate; otherwise the \
                rate is ``(init - final) * (1 - index / (num_steps - 1)) ** power + final``.
            Arguments:
                - index (:obj:`int`): Current iteration.
            Returns:
                - lr (:obj:`float`): Learning rate for ``index``.
            """
            if index == -1:
                return self._init
            # A schedule with a single step has no interval to interpolate over; clamping the
            # denominator avoids a ZeroDivisionError when ``num_steps == 1``.
            span = max(self._num_steps - 1, 1)
            progress = float(index) / float(span)
            return ((self._init - self._final) * ((1 - progress) ** (self._power))) + self._final

    def __init__(
            self,
            iters: int = 100,
            use_langevin_negative_samples: bool = True,
            train_samples: int = 8,
            inference_samples: int = 512,
            stepsize_scheduler: dict = dict(
                init=0.5,
                final=1e-5,
                power=2.0,
                # num_steps,
            ),
            optimize_again: bool = True,
            again_stepsize_scheduler: dict = dict(
                init=1e-5,
                final=1e-5,
                power=2.0,
                # num_steps,
            ),
            device: str = 'cpu',
            # langevin_step
            noise_scale: float = 0.5,
            grad_clip=None,
            delta_action_clip: float = 0.5,
            add_grad_penalty: bool = True,
            grad_norm_type: str = 'inf',
            grad_margin: float = 1.0,
            grad_loss_weight: float = 1.0,
            **kwargs,
    ):
        """
        Overview:
            Initialize the MCMC.
        Arguments:
            - iters (:obj:`int`): Number of iterations.
            - use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - stepsize_scheduler (:obj:`dict`): Keyword arguments of ``PolynomialScheduler`` (without \
                ``num_steps``) for the Langevin sampler. The dict is copied, never modified.
            - optimize_again (:obj:`bool`): Whether to run a second optimization.
            - again_stepsize_scheduler (:obj:`dict`): Same as ``stepsize_scheduler`` but for the second \
                optimization. The dict is copied, never modified.
            - device (:obj:`str`): Device.
            - noise_scale (:obj:`float`): Scale of the Gaussian noise added in every Langevin step.
            - grad_clip (:obj:`float`): Gradient clip; no clipping when falsy.
            - delta_action_clip (:obj:`float`): Action clip, as a fraction of half the action range.
            - add_grad_penalty (:obj:`bool`): Whether to add gradient penalty.
            - grad_norm_type (:obj:`str`): Gradient norm type, one of ``'1'``, ``'2'``, ``'inf'``.
            - grad_margin (:obj:`float`): Gradient margin.
            - grad_loss_weight (:obj:`float`): Gradient loss weight.
            - kwargs: Extra keyword arguments are accepted and ignored.
        """
        # Initialised like in DFO; filled in by ``set_action_bounds``.
        self.action_bounds = None
        self.iters = iters
        self.use_langevin_negative_samples = use_langevin_negative_samples
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        # Copy the scheduler configs: ``num_steps`` is written into them later, which must not leak
        # into the shared default dicts or into a config dict owned by the caller.
        self.stepsize_scheduler = dict(stepsize_scheduler)
        self.optimize_again = optimize_again
        self.again_stepsize_scheduler = dict(again_stepsize_scheduler)
        self.device = device

        self.noise_scale = noise_scale
        self.grad_clip = grad_clip
        self.delta_action_clip = delta_action_clip
        self.add_grad_penalty = add_grad_penalty
        self.grad_norm_type = grad_norm_type
        self.grad_margin = grad_margin
        self.grad_loss_weight = grad_loss_weight

    @staticmethod
    def _gradient_wrt_act(
            obs: torch.Tensor,
            action: torch.Tensor,
            ebm: nn.Module,
            create_graph: bool = False,
    ) -> torch.Tensor:
        """
        Overview:
            Calculate gradient of the summed energy w.r.t action. ``action.requires_grad`` is switched \
            on for the computation and switched off again afterwards.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - create_graph (:obj:`bool`): Whether to create graph.
        Returns:
            - grad (:obj:`torch.Tensor`): Gradient w.r.t action.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - grad (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        action.requires_grad_(True)
        energy = ebm.forward(obs, action).sum()
        # `create_graph` set to `True` when second order derivative
        #  is needed i.e, d(de/da)/d_param
        grad = torch.autograd.grad(energy, action, create_graph=create_graph)[0]
        action.requires_grad_(False)
        return grad

    def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Calculate gradient penalty: the mean of ``max(norm(dE/da) - grad_margin, 0) ** 2`` scaled by \
            ``grad_loss_weight``. Returns the plain float ``0.`` when ``add_grad_penalty`` is disabled.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - loss (:obj:`torch.Tensor`): Gradient penalty (a scalar).
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`.
        """
        if not self.add_grad_penalty:
            return 0.
        # (B, N+1, A), this gradient is differentiable w.r.t model parameters
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm, create_graph=True)

        def compute_grad_norm(grad_norm_type, de_dact) -> torch.Tensor:
            # de_deact: B, N+1, A
            # return:   B, N+1
            grad_norm_type_to_ord = {
                '1': 1,
                '2': 2,
                'inf': float('inf'),
            }
            norm_ord = grad_norm_type_to_ord[grad_norm_type]
            return torch.linalg.norm(de_dact, norm_ord, dim=-1)

        # (B, N+1)
        grad_norms = compute_grad_norm(self.grad_norm_type, de_dact)
        # Only the part of the norm above the margin is penalised.
        grad_norms = grad_norms - self.grad_margin
        grad_norms = grad_norms.clamp(min=0., max=1e10)
        grad_norms = grad_norms.pow(2)

        grad_loss = grad_norms.mean()
        return grad_loss * self.grad_loss_weight

    # can not use @torch.no_grad() during the inference
    # because we need to calculate gradient w.r.t inputs as MCMC updates.
    @no_ebm_grad()
    def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Run one langevin MCMC step: move the actions against the (optionally clipped) energy \
            gradient plus Gaussian noise, limit the per-step change and clamp to the action bounds.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - stepsize (:obj:`float`): Step size.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        l_lambda = 1.0
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm)

        if self.grad_clip:
            de_dact = de_dact.clamp(min=-self.grad_clip, max=self.grad_clip)

        gradient_scale = 0.5
        de_dact = (gradient_scale * l_lambda * de_dact + torch.randn_like(de_dact) * l_lambda * self.noise_scale)

        delta_action = stepsize * de_dact
        # The per-step change is limited to a fraction of half the action range of every dimension.
        delta_action_clip = self.delta_action_clip * 0.5 * (self.action_bounds[1] - self.action_bounds[0])
        delta_action = delta_action.clamp(min=-delta_action_clip, max=delta_action_clip)

        action = action - delta_action
        action = action.clamp(min=self.action_bounds[0], max=self.action_bounds[1])

        return action

    @no_ebm_grad()
    def _langevin_action_given_obs(
            self,
            obs: torch.Tensor,
            action: torch.Tensor,
            ebm: nn.Module,
            scheduler: BaseScheduler = None
    ) -> torch.Tensor:
        """
        Overview:
            Run langevin MCMC for `self.iters` steps. Without an explicit scheduler, a \
            ``PolynomialScheduler`` is built from ``self.stepsize_scheduler`` with ``num_steps = self.iters``.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - scheduler (:obj:`BaseScheduler`): Learning rate scheduler.
        Returns:
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        if scheduler is None:
            self.stepsize_scheduler['num_steps'] = self.iters
            scheduler = MCMC.PolynomialScheduler(**self.stepsize_scheduler)
        # NOTE(review): step 0 uses ``get_rate(-1)`` and step ``k > 0`` uses ``get_rate(k - 1)``. With the
        # polynomial schedule the first two steps therefore both run at the initial rate, and the rate
        # queried after the last step is never used -- confirm this shift is intended.
        stepsize = scheduler.get_rate(-1)
        for i in range(self.iters):
            action = self._langevin_step(obs, action, stepsize, ebm)
            stepsize = scheduler.get_rate(i)
        return action

    @no_ebm_grad()
    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss. Uniform samples \
            are returned directly unless ``use_langevin_negative_samples`` is set, in which case they \
            are first refined with Langevin MCMC.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action_samples = opt.sample(obs, ebm)
        """
        obs, uniform_action_samples = self._sample(obs, self.train_samples)
        if not self.use_langevin_negative_samples:
            return obs, uniform_action_samples
        langevin_action_samples = self._langevin_action_given_obs(obs, uniform_action_samples, ebm)
        return obs, langevin_action_samples

    @no_ebm_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation: refine uniform \
            candidates with Langevin MCMC (optionally twice) and return the most likely candidate.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        # (B, N, O), (B, N, A)
        obs, uniform_action_samples = self._sample(obs, self.inference_samples)
        action_samples = self._langevin_action_given_obs(
            obs,
            uniform_action_samples,
            ebm,
        )

        # Run a second optimization, a trick for more precise inference
        if self.optimize_again:
            self.again_stepsize_scheduler['num_steps'] = self.iters
            action_samples = self._langevin_action_given_obs(
                obs,
                action_samples,
                ebm,
                scheduler=MCMC.PolynomialScheduler(**self.again_stepsize_scheduler),
            )

        # action_samples: B, N, A
        return self._get_best_action_sample(obs, action_samples, ebm)
@MODEL_REGISTRY.register('ebm')
class EBM(nn.Module):
    """
    Overview:
        Energy based model that scores concatenated (observation, action) pairs.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: int,
            action_shape: int,
            hidden_size: int = 512,
            hidden_layer_num: int = 4,
            **kwargs,
    ):
        """
        Overview:
            Build the network: a linear layer on the concatenated input, a ReLU and a ``RegressionHead`` \
            with output size 1 and no final tanh.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers handed to ``RegressionHead``.
            - kwargs: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        layers = [
            nn.Linear(obs_shape + action_shape, hidden_size),
            nn.ReLU(),
            RegressionHead(
                hidden_size,
                1,
                hidden_layer_num,
                final_tanh=False,
            ),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs, action):
        """
        Overview:
            Concatenate observation and action along the last dimension, run the network and return \
            the ``'pred'`` entry of its output.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy, presumably of shape (B, N) as the optimizers in this \
                file treat it (depends on ``RegressionHead`` -- TODO confirm).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> ebm = EBM(4, 5)
            >>> pred = ebm(obs, action)
        """
        joint_input = torch.cat([obs, action], -1)
        head_output = self.net(joint_input)
        return head_output['pred']
@MODEL_REGISTRY.register('arebm')
class AutoregressiveEBM(nn.Module):
    """
    Overview:
        Autoregressive energy based model: one ``EBM`` per action dimension, where the ``i``-th model \
        only sees the first ``i + 1`` action dimensions.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: int,
            action_shape: int,
            hidden_size: int = 512,
            hidden_layer_num: int = 4,
            **kwargs,
    ):
        """
        Overview:
            Initialize the AutoregressiveEBM.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape; also the number of sub-models created.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
            - kwargs: Extra keyword arguments are accepted and ignored, mirroring ``EBM``.
        """
        super().__init__()
        # Sub-model ``i`` takes the observation plus the action prefix of length ``i + 1``.
        self.ebm_list = nn.ModuleList(
            [EBM(obs_shape, prefix_len, hidden_size, hidden_layer_num) for prefix_len in range(1, action_shape + 1)]
        )

    def forward(self, obs, action):
        """
        Overview:
            Evaluate every sub-model on its action prefix and stack the energies along a new last \
            dimension.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N, A).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> arebm = AutoregressiveEBM(4, 5)
            >>> pred = arebm(obs, action)
        """
        energies = [ebm(obs, action[..., :i + 1]) for i, ebm in enumerate(self.ebm_list)]
        # ``dim`` is the documented keyword of ``torch.stack``.
        return torch.stack(energies, dim=-1)