|
"""Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE""" |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
from torch import nn |
|
from abc import abstractmethod |
|
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional |
|
from ding.utils.type_helper import Tensor |
|
|
|
|
|
class VanillaVAE(nn.Module):
    """
    Overview:
        Implementation of Vanilla variational autoencoder for action reconstruction.
        The encoder maps an (obs, action) pair to the parameters of a diagonal Gaussian over a latent
        action; the decoder maps a latent action plus an observation encoding to a reconstructed
        action and to the output of an auxiliary prediction head.
    Interfaces:
        ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``,
        ``forward``, ``loss_function`` .
    """

    def __init__(
            self,
            action_shape: int,
            obs_shape: int,
            latent_size: int,
            hidden_dims: Optional[List[int]] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Build the encoder and decoder networks.
        Arguments:
            - action_shape (:obj:`int`): Input size of the action encoder and output size of the
                action reconstruction head.
            - obs_shape (:obj:`int`): Input size of the observation encoder and output size of the
                prediction head.
            - latent_size (:obj:`int`): Size of the latent action (``mu``, ``log_var`` and ``z``).
            - hidden_dims (:obj:`Optional[List[int]]`): Hidden layer widths, at least two entries.
                The encoder uses ``hidden_dims[0]`` and ``hidden_dims[1]``, the decoder uses
                ``hidden_dims[-1]`` and ``hidden_dims[-2]``. ``None`` means ``[256, 256]``.
                Since ``decode`` multiplies a ``hidden_dims[-1]``-wide tensor element-wise with the
                ``hidden_dims[0]``-wide observation encoding, the first and last widths have to match.
            - kwargs: Accepted for interface compatibility; not used by this class.
        """
        super(VanillaVAE, self).__init__()
        # A fresh list per instance: a list literal as the default value would be shared by every
        # instance, and copying also decouples the model from later mutation of the caller's list.
        hidden_dims = [256, 256] if hidden_dims is None else list(hidden_dims)

        self.action_shape = action_shape
        self.obs_shape = obs_shape
        self.latent_size = latent_size
        self.hidden_dims = hidden_dims

        # NOTE: the layers below are created in the same order and under the same attribute names
        # as before, so seeded initialisation and ``state_dict`` keys are unaffected.

        # Encoder: separate heads for action and observation, then a shared trunk with two
        # linear heads producing the mean and the log-variance of the latent Gaussian.
        self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU())
        self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[0]), nn.ReLU())

        self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU())
        self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size)
        self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)

        # Decoder: latent action head, then a shared trunk.
        self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU())
        self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU())

        # Reconstruction head; the final Tanh bounds the reconstructed action to [-1, 1].
        self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())

        # Two-layer prediction head with an unbounded linear output of size ``obs_shape``.
        self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
        self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)

        # Neither read nor written by any method of this class; kept because external code may
        # access the attribute.
        self.obs_encoding = None

    def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Encodes the input by passing through the encoder network and returns the latent codes.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`) and
                ``action`` (:obj:`torch.Tensor`), representing the observation and agent's action
                respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`),
                ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`)
                representing latent codes.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
            - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
        """
        action_encoding = self.encode_action_head(input['action'])
        obs_encoding = self.encode_obs_head(input['obs'])

        # Observation and action features are fused by element-wise product (not concatenation).
        fused = obs_encoding * action_encoding
        hidden = self.encode_common(fused)

        return {
            'mu': self.encode_mu_head(hidden),
            'log_var': self.encode_logvar_head(hidden),
            'obs_encoding': obs_encoding,
        }

    def _decode_from_encoding(self, latent: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Shared decoder pipeline used by ``decode`` and ``decode_with_obs``. ``latent`` is fed to
            the decoder as given; any squashing is the responsibility of the calling method.
        Arguments:
            - latent (:obj:`torch.Tensor`): Latent action fed to ``decode_action_head``.
            - obs_encoding (:obj:`torch.Tensor`): Output of ``encode_obs_head``.
        Returns:
            - outputs (:obj:`Dict`): Dict with keys ``reconstruction_action`` and ``predition_residual``.
        """
        fused = self.decode_action_head(latent) * obs_encoding
        shared = self.decode_common(fused)

        reconstruction_action = self.decode_reconst_action_head(shared)
        predition_residual = self.decode_prediction_head_layer2(self.decode_prediction_head_layer1(shared))
        # The key spelling 'predition_residual' is part of the existing interface; do not "fix" it.
        return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}

    def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs_encoding onto the original action space.
            ``z`` is squashed with ``tanh`` before it enters the decoder.
        Arguments:
            - z (:obj:`torch.Tensor`): the sampled latent action
            - obs_encoding (:obj:`torch.Tensor`): observation encoding
        Returns:
            - outputs (:obj:`Dict`): Dict holding the decoder outputs listed in ``ReturnsKeys``.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): The action reconstructed by the VAE, in [-1, 1].
            - predition_residual (:obj:`torch.Tensor`): Output of the prediction head.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``
            - reconstruction_action (:obj:`torch.Tensor`): :math:`(B, A)`, where A is ``action dim``
            - predition_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where O is ``observation dim``
        """
        # z sampled by ``reparameterize`` is unbounded, hence the tanh here.
        return self._decode_from_encoding(torch.tanh(z), obs_encoding)

    def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs onto the original action space.
            Using the method self.encode_obs_head(obs) to get the obs_encoding.
            Unlike ``decode``, ``z`` is NOT passed through ``tanh`` here.
        Arguments:
            - z (:obj:`torch.Tensor`): the latent action
            - obs (:obj:`torch.Tensor`): observation
        Returns:
            - outputs (:obj:`Dict`): Dict holding the decoder outputs listed in ``ReturnsKeys``.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): The action reconstructed by the VAE, in [-1, 1].
            - predition_residual (:obj:`torch.Tensor`): Output of the prediction head.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``
            - reconstruction_action (:obj:`torch.Tensor`): :math:`(B, A)`, where A is ``action dim``
            - predition_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where O is ``observation dim``
        """
        obs_encoding = self.encode_obs_head(obs)
        # NOTE(review): no tanh on z, in contrast to ``decode``. Presumably the caller supplies a
        # latent action that is already bounded (e.g. a tanh-squashed policy output) -- TODO confirm.
        return self._decode_from_encoding(z, obs_encoding)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Overview:
            Reparameterization trick to sample from N(mu, var) from N(0,1).
        Arguments:
            - mu (:obj:`torch.Tensor`): Mean of the latent Gaussian
            - logvar (:obj:`torch.Tensor`): Log-variance of the latent Gaussian (the standard
                deviation is recovered as ``exp(0.5 * logvar)``)
        Returns:
            - z (:obj:`torch.Tensor`): ``mu + eps * std`` with ``eps`` drawn from N(0, 1).
        Shapes:
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Dict[str, Tensor], **kwargs) -> dict:
        """
        Overview:
            Encode the input, sample a latent action ``z`` from ``mu`` and ``log_var`` with the
            reparameterization trick, then decode ``z`` together with ``obs_encoding``.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`)
                and ``action`` (:obj:`torch.Tensor`), representing the observation
                and agent's action respectively.
            - kwargs: Accepted for interface compatibility; not used.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action``
                (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`),
                ``input`` (:obj:`Dict`, the unmodified argument), ``mu`` (:obj:`torch.Tensor`),
                ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`,
                where B is batch size and O is ``observation dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        encode_output = self.encode(input)
        mu, log_var = encode_output['mu'], encode_output['log_var']
        z = self.reparameterize(mu, log_var)
        decode_output = self.decode(z, encode_output['obs_encoding'])
        return {
            'recons_action': decode_output['reconstruction_action'],
            # The decoder key is spelt 'predition_residual'; the public key here is spelt correctly.
            'prediction_residual': decode_output['predition_residual'],
            'input': input,
            'mu': mu,
            'log_var': log_var,
            'z': z
        }

    def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        """
        Overview:
            Computes the VAE loss function:
            ``loss = reconstruction_loss + kld_weight * kld_loss + predict_weight * predict_loss``.
        Arguments:
            - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``,
                ``prediction_residual``, ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
            - kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight``.
                Both are required; a missing one raises ``KeyError``.
        Returns:
            - outputs (:obj:`Dict[str, Tensor]`): Dict containing different ``loss`` results, including
                ``loss``, ``reconstruction_loss``, ``kld_loss``, ``predict_loss``.
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size
                and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size
                and O is ``observation dim``.
            - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
        """
        mu = args['mu']
        log_var = args['log_var']
        kld_weight = kwargs['kld_weight']
        predict_weight = kwargs['predict_weight']

        # Mean squared error between the reconstructed and the original action.
        recons_loss = F.mse_loss(args['recons_action'], args['original_action'])
        # KL(N(mu, exp(log_var)) || N(0, I)): summed over latent dimensions, averaged over the batch.
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_per_sample, dim=0)
        # Mean squared error between the prediction head output and its target.
        predict_loss = F.mse_loss(args['prediction_residual'], args['true_residual'])

        loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
        return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}
|
|