File size: 11,962 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
"""Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE"""
import torch
from torch.nn import functional as F
from torch import nn
from abc import abstractmethod
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional
from ding.utils.type_helper import Tensor
class VanillaVAE(nn.Module):
    """
    Overview:
        Vanilla variational autoencoder for action reconstruction, conditioned on the observation.
        The encoder maps an (obs, action) pair to a latent Gaussian; the decoder maps a latent
        action plus the observation encoding back to an action and to a residual prediction
        of size ``obs_shape``.
    Interfaces:
        ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``,
        ``forward``, ``loss_function`` .
    """

    def __init__(
            self,
            action_shape: int,
            obs_shape: int,
            latent_size: int,
            hidden_dims: List = [256, 256],
            **kwargs
    ) -> None:
        """
        Overview:
            Build the encoder and decoder networks.
        Arguments:
            - action_shape (:obj:`int`): Dimension of the action vector.
            - obs_shape (:obj:`int`): Dimension of the observation vector.
            - latent_size (:obj:`int`): Dimension of the latent action.
            - hidden_dims (:obj:`List`): Hidden layer widths. Needs at least two entries, and \
                ``hidden_dims[0]`` must equal ``hidden_dims[-1]`` because the decoder multiplies \
                a ``hidden_dims[-1]``-wide tensor elementwise with the ``hidden_dims[0]``-wide \
                observation encoding.
            - kwargs: Accepted for signature compatibility; not used here.
        """
        super(VanillaVAE, self).__init__()
        self.action_shape = action_shape
        self.obs_shape = obs_shape
        self.latent_size = latent_size
        # Store a copy: the default argument is a single list object shared by every call,
        # so keeping a reference would let one instance's mutation leak into later instances.
        hidden_dims = list(hidden_dims)
        self.hidden_dims = hidden_dims

        # Build Encoder
        self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU())
        self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[0]), nn.ReLU())
        self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU())
        self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size)
        self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)

        # Build Decoder
        self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU())
        self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU())
        # The reconstructed action is squashed into [-1, 1] by the final tanh.
        self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())

        # Residual prediction head (two layers, output width ``obs_shape``).
        self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
        self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)

        # Kept for backward compatibility; no method of this class reads or writes it.
        self.obs_encoding = None

    def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Encodes the input by passing through the encoder network and returns the latent codes.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) and \
                `action` (:obj:`torch.Tensor`), representing the observation and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \
                representing latent codes.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
            - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
        """
        action_encoding = self.encode_action_head(input['action'])
        obs_encoding = self.encode_obs_head(input['obs'])
        # Observation and action are fused by elementwise product.
        # TODO(pu): evaluate add / concatenation, or a separate obs conditioning network.
        fused = obs_encoding * action_encoding
        hidden = self.encode_common(fused)

        # Two linear heads give the mean and log-variance of the latent Gaussian.
        mu = self.encode_mu_head(hidden)
        log_var = self.encode_logvar_head(hidden)

        return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding}

    def _decode_from_hidden(self, action_decoding: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Shared decoder tail: fuse the decoded latent action with the observation encoding and
            run the reconstruction and residual-prediction heads.
        Arguments:
            - action_decoding (:obj:`torch.Tensor`): output of ``decode_action_head``
            - obs_encoding (:obj:`torch.Tensor`): output of ``encode_obs_head``
        Returns:
            - outputs (:obj:`Dict`): Dict with keys ``reconstruction_action`` and ``predition_residual``.
        """
        action_obs_decoding = self.decode_common(action_decoding * obs_encoding)
        reconstruction_action = self.decode_reconst_action_head(action_obs_decoding)
        predition_residual = self.decode_prediction_head_layer2(
            self.decode_prediction_head_layer1(action_obs_decoding)
        )
        # NOTE: the key ``predition_residual`` is misspelled but is part of the existing interface.
        return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}

    def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs_encoding onto the original action space.
            ``z`` is passed through ``tanh`` first, so it does not need to be bounded.
        Arguments:
            - z (:obj:`torch.Tensor`): the sampled latent action
            - obs_encoding (:obj:`torch.Tensor`): observation encoding
        Returns:
            - outputs (:obj:`Dict`): Dict containing the decoder outputs.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the decoder.
            - predition_residual (:obj:`torch.Tensor`): the output of the residual prediction head.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``
            - reconstruction_action (:obj:`torch.Tensor`): :math:`(B, A)`, where A is ``action_shape``
            - predition_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where O is ``obs_shape``
        """
        # NOTE: tanh is applied here because z (sampled via reparameterization) is not bounded.
        action_decoding = self.decode_action_head(torch.tanh(z))
        return self._decode_from_hidden(action_decoding, obs_encoding)

    def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs onto the original action space.
            Uses ``self.encode_obs_head(obs)`` to get the obs_encoding. Unlike ``decode``,
            no ``tanh`` is applied to ``z`` here.
        Arguments:
            - z (:obj:`torch.Tensor`): the latent action
            - obs (:obj:`torch.Tensor`): observation
        Returns:
            - outputs (:obj:`Dict`): Dict containing the decoder outputs.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the decoder.
            - predition_residual (:obj:`torch.Tensor`): the output of the residual prediction head.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``
            - reconstruction_action (:obj:`torch.Tensor`): :math:`(B, A)`, where A is ``action_shape``
            - predition_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where O is ``obs_shape``
        """
        obs_encoding = self.encode_obs_head(obs)
        # TODO(pu): z is expected to be already bounded here (per the original note it comes
        # from a td3 policy that applies tanh), hence no tanh in this path.
        action_decoding = self.decode_action_head(z)
        return self._decode_from_hidden(action_decoding, obs_encoding)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Overview:
            Reparameterization trick: draw a sample from N(mu, var) using noise from N(0, 1),
            i.e. ``mu + eps * exp(0.5 * logvar)``.
        Arguments:
            - mu (:obj:`torch.Tensor`): Mean of the latent Gaussian
            - logvar (:obj:`torch.Tensor`): Log-variance of the latent Gaussian
        Returns:
            - z (:obj:`torch.Tensor`): The sampled latent, same shape as ``mu``.
        Shapes:
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Dict[str, Tensor], **kwargs) -> dict:
        """
        Overview:
            Encode the input, sample ``z`` from ``mu`` and ``log_var`` by reparameterization,
            then decode ``z`` together with ``obs_encoding``.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) \
                and `action` (:obj:`torch.Tensor`), representing the observation \
                and agent's action respectively.
            - kwargs: Accepted for signature compatibility; not used here.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                ``input`` (the dict that was passed in), ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, \
                where B is batch size and O is ``observation dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        encode_output = self.encode(input)
        mu, log_var = encode_output['mu'], encode_output['log_var']
        z = self.reparameterize(mu, log_var)
        decode_output = self.decode(z, encode_output['obs_encoding'])
        return {
            'recons_action': decode_output['reconstruction_action'],
            'prediction_residual': decode_output['predition_residual'],
            'input': input,
            'mu': mu,
            'log_var': log_var,
            'z': z
        }

    def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        """
        Overview:
            Computes the VAE loss:
            ``recons_loss + kld_weight * kld_loss + predict_weight * predict_loss``.
        Arguments:
            - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, \
                ``prediction_residual``, ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
            - kwargs (:obj:`Dict`): Must contain keywords ``kld_weight`` and ``predict_weight``; \
                a missing keyword raises ``KeyError``.
        Returns:
            - outputs (:obj:`Dict[str, Tensor]`): Dict containing different ``loss`` results, including ``loss``, \
                ``reconstruction_loss``, ``kld_loss``, ``predict_loss``.
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size \
                and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size \
                and O is ``observation dim``.
            - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
        """
        mu = args['mu']
        log_var = args['log_var']
        kld_weight = kwargs['kld_weight']
        predict_weight = kwargs['predict_weight']

        # Action reconstruction error.
        recons_loss = F.mse_loss(args['recons_action'], args['original_action'])
        # KL divergence to a standard normal: summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        # Residual prediction error.
        predict_loss = F.mse_loss(args['prediction_residual'], args['true_residual'])

        loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
        return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}
|