from einops import rearrange, repeat
import torch
import torch.nn as nn
from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.nn_blocks import (
MCG,
MAB,
MHB,
SequenceDecoderLSTM,
SequenceDecoderMLP,
SequenceEncoderLSTM,
SequenceEncoderMLP,
SequenceEncoderMaskedLSTM,
)
class DecoderNN(nn.Module):
"""Decoder neural network that decodes input tensors into a single output tensor.
It contains an interaction layer that (re-)compute the interactions between the agents in the scene.
This implies that a given latent sample for one agent will be affecting the predictions of the othe agents too.
Args:
params: dataclass defining the necessary parameters
"""
def __init__(
self,
params: CVAEParams,
) -> None:
super().__init__()
self.dt = params.dt
self.state_dim = params.state_dim
self.dynamic_state_dim = params.dynamic_state_dim
self.hidden_dim = params.hidden_dim
self.num_steps_future = params.num_steps_future
self.latent_dim = params.latent_dim
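        # Build blocks from configuration strings: sequence_encoder_type and sequence_decoder_type
        # accept "MLP", "LSTM", or "maskedLSTM"; interaction_type accepts "Attention"/"MAB",
        # "ContextGating"/"MCG", or "Hybrid"/"MHB" (any other value disables interaction).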
if params.sequence_encoder_type == "MLP":
self._agent_encoder_past = SequenceEncoderMLP(
params.state_dim,
params.hidden_dim,
params.num_hidden_layers,
params.num_steps,
params.is_mlp_residual,
)
elif params.sequence_encoder_type == "LSTM":
self._agent_encoder_past = SequenceEncoderLSTM(
params.state_dim, params.hidden_dim
)
elif params.sequence_encoder_type == "maskedLSTM":
self._agent_encoder_past = SequenceEncoderMaskedLSTM(
params.state_dim, params.hidden_dim
)
else:
raise RuntimeError(
f"Got sequence encoder type {params.sequence_decoder_type} but only knows one of: 'MLP', 'LSTM', 'maskedLSTM' "
)
self._combine_z_past = nn.Linear(
params.hidden_dim + params.latent_dim, params.hidden_dim
)
if params.interaction_type == "Attention" or params.interaction_type == "MAB":
self._interaction = MAB(
params.hidden_dim, params.num_attention_heads, params.num_blocks
)
elif (
params.interaction_type == "ContextGating"
or params.interaction_type == "MCG"
):
self._interaction = MCG(
params.hidden_dim,
params.mcg_dim_expansion,
params.mcg_num_layers,
params.num_blocks,
params.is_mlp_residual,
)
elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
self._interaction = MHB(
params.hidden_dim,
params.num_attention_heads,
params.mcg_dim_expansion,
params.mcg_num_layers,
params.num_blocks,
params.is_mlp_residual,
)
        else:
            # Any other interaction type disables interaction: use an identity mapping.
            self._interaction = lambda x, *args, **kwargs: x
if params.sequence_decoder_type == "MLP":
self._decoder = SequenceDecoderMLP(
params.hidden_dim,
params.num_hidden_layers,
params.num_steps_future,
params.is_mlp_residual,
)
elif params.sequence_decoder_type == "LSTM":
self._decoder = SequenceDecoderLSTM(params.hidden_dim)
        elif params.sequence_decoder_type == "maskedLSTM":
            # The masked variant falls back to the standard LSTM decoder here.
            self._decoder = SequenceDecoderLSTM(params.hidden_dim)
else:
raise RuntimeError(
f"Got sequence decoder type {params.sequence_decoder_type} but only knows one of: 'MLP', 'LSTM', 'maskedLSTM' "
)
def forward(
self,
z_samples: torch.Tensor,
mask_z: torch.Tensor,
x: torch.Tensor,
mask_x: torch.Tensor,
encoded_absolute: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
) -> torch.Tensor:
"""Forward function that decodes input tensors into an output tensor of size
(batch_size, num_agents, (n_samples), num_steps_future, state_dim)
Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
mask_z: (batch_size, num_agents) tensor of bool mask
x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
mask_map: (batch_size, num_objects) tensor of bool mask
Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, hidden_dim) output feature tensor
"""
encoded_x = self._agent_encoder_past(x, mask_x)
squeeze_output_sample_dim = False
if z_samples.ndim == 3:
batch_size, num_agents, latent_dim = z_samples.shape
num_samples = 1
z_samples = rearrange(z_samples, "b a l -> b a () l")
squeeze_output_sample_dim = True
else:
batch_size, num_agents, num_samples, latent_dim = z_samples.shape
        # Fold the sample dimension into the batch dimension so that the interaction and
        # decoder layers treat each latent sample as an independent scene.
        mask_z = repeat(mask_z, "b a -> (b s) a", s=num_samples)
        mask_map = repeat(mask_map, "b o -> (b s) o", s=num_samples)
        encoded_x = repeat(encoded_x, "b a l -> (b s) a l", s=num_samples)
        encoded_absolute = repeat(
            encoded_absolute, "b a l -> (b s) a l", s=num_samples
        )
        encoded_map = repeat(encoded_map, "b o l -> (b s) o l", s=num_samples)
        z_samples = rearrange(z_samples, "b a s l -> (b s) a l")
        # Combine the latent samples with the encoded history, let agents interact, then decode the future sequence
        h = self._combine_z_past(torch.cat([z_samples, encoded_x], dim=-1))
        h = self._interaction(h, mask_z, encoded_absolute, encoded_map, mask_map)
        h = self._decoder(h, self.num_steps_future)
        if not squeeze_output_sample_dim:
            # Restore the per-sample dimension that was folded into the batch dimension
            h = rearrange(h, "(b s) a t l -> b a s t l", b=batch_size, s=num_samples)
return h
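# Illustrative sketch (toy shapes assumed here, not taken from any configuration): how
# DecoderNN.forward folds the sample axis into the batch axis before interaction and
# restores it afterwards. The helper below is only for illustration and is never called.
def _example_sample_folding() -> torch.Size:
    batch, agents, samples, latent = 2, 3, 4, 8
    z = torch.randn(batch, agents, samples, latent)
    z_flat = rearrange(z, "b a s l -> (b s) a l")  # (8, 3, 8): each sample becomes its own scene
    z_back = rearrange(z_flat, "(b s) a l -> b a s l", b=batch, s=samples)
    assert torch.equal(z, z_back)  # folding and unfolding are exact inverses
    return z_back.shape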
class CVAEAccelerationDecoder(nn.Module):
"""Decoder architecture for conditional variational autoencoder
Args:
model: decoder neural network that transforms input tensors to an output sequence
"""
def __init__(
self,
model: nn.Module,
) -> None:
super().__init__()
self._model = model
self._output_layer = nn.Linear(model.hidden_dim, 2)
def forward(
self,
z_samples: torch.Tensor,
mask_z: torch.Tensor,
x: torch.Tensor,
mask_x: torch.Tensor,
encoded_absolute: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
offset: torch.Tensor,
) -> torch.Tensor:
"""Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim).
        It first predicts accelerations that are doubly integrated to produce the output
        state sequence with positions, angles, and velocities: (x, y, theta, vx, vy), (x, y, vx, vy), or (x, y).
Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
mask_z: (batch_size, num_agents) tensor of bool mask
x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            offset: (batch_size, num_agents, state_dim) tensor of agent offsets, used to recover the initial velocity
        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim) output tensor. The sample
            dimension is absent if z_samples has no sample dimension (3D input).
"""
h = self._model(
z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
)
h = self._output_layer(h)
dt = self._model.dt
initial_position = x[..., -1:, :2].clone()
# If shape is 5 it should be (x, y, angle, vx, vy)
if offset.shape[-1] == 5:
initial_velocity = offset[..., 3:5].clone().unsqueeze(-2)
# else if shape is 4 it should be (x, y, vx, vy)
elif offset.shape[-1] == 4:
initial_velocity = offset[..., 2:4].clone().unsqueeze(-2)
        # Otherwise fall back to the last observed velocity in the history x
        elif x.shape[-1] == 5:
            initial_velocity = x[..., -1:, 3:5].clone()
        elif x.shape[-1] == 4:
            initial_velocity = x[..., -1:, 2:4].clone()
        else:
            # Last resort: finite-difference velocity from the last two observed positions
            initial_velocity = (x[..., -1:, :] - x[..., -2:-1, :]) / dt
output = torch.zeros(
(*h.shape[:-1], self._model.dynamic_state_dim), device=h.device
)
        # If the output has a sample dimension, add one to the initial position and velocity so they broadcast
if output.ndim == 5:
initial_position = initial_position.unsqueeze(-3)
initial_velocity = initial_velocity.unsqueeze(-3)
        if self._model.dynamic_state_dim == 5:
            # (x, y, theta, vx, vy): integrate accelerations into velocity changes, then into
            # positions; the heading angle follows the velocity direction.
            output[..., 3:5] = h.cumsum(-2) * dt
            output[..., :2] = (output[..., 3:5].clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
            output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
        elif self._model.dynamic_state_dim == 4:
            # (x, y, vx, vy): same double integration without the heading angle.
            output[..., 2:4] = h.cumsum(-2) * dt
            output[..., :2] = (output[..., 2:4].clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
        else:
            # (x, y): only positions are produced.
            velocity = h.cumsum(-2) * dt
            output = (velocity.clone() + initial_velocity).cumsum(
                -2
            ) * dt + initial_position
return output
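# Minimal sketch of the double integration used by CVAEAccelerationDecoder, with assumed
# toy arguments (this helper is illustrative only and not part of the model interface):
# accelerations are summed into velocity changes, which are then summed into positions.
def _example_double_integration(
    acceleration: torch.Tensor,
    initial_position: torch.Tensor,
    initial_velocity: torch.Tensor,
    dt: float,
) -> torch.Tensor:
    """Integrate accelerations of shape (..., num_steps, 2) twice into positions of the same shape."""
    velocity_change = acceleration.cumsum(-2) * dt
    return (velocity_change + initial_velocity).cumsum(-2) * dt + initial_position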
class CVAEParametrizedDecoder(nn.Module):
"""Decoder architecture for conditional variational autoencoder
Args:
model: decoder neural network that transforms input tensors to an output sequence
"""
def __init__(
self,
model: nn.Module,
) -> None:
super().__init__()
self._model = model
        # Polynomial order of the parametrized path: x and y each get `order` coefficients,
        # and the output layer additionally predicts one progress-rate value per future step.
        self._order = 3
        self._output_layer = nn.Linear(
            model.hidden_dim * model.num_steps_future,
            2 * self._order + model.num_steps_future,
        )
def polynomial(self, x: torch.Tensor, params: torch.Tensor):
"""Polynomial function that takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future) and
a parameter tensor of shape (batch_size, num_agents, (n_samples), self._order*2) and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future)
"""
h = x.clone()
squeeze = False
if h.ndim == 3:
h = h.unsqueeze(2)
params = params.unsqueeze(2)
squeeze = True
h = repeat(
h,
"batch agents samples sequence -> batch agents samples sequence two order",
order=self._order,
two=2,
).cumprod(-1)
h = h * params.view(*params.shape[:-1], 1, 2, self._order)
h = h.sum(-1)
if squeeze:
h = h.squeeze(2)
return h
def dpolynomial(self, x: torch.Tensor, params: torch.Tensor):
"""Derivative of the polynomial function that takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future) and
a parameter tensor of shape (batch_size, num_agents, (n_samples), self._order*2) and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future)
"""
h = x.clone()
squeeze = False
if h.ndim == 3:
h = h.unsqueeze(2)
params = params.unsqueeze(2)
squeeze = True
h = repeat(
h,
"batch agents samples sequence -> batch agents samples sequence two order",
order=self._order - 1,
two=2,
)
        h = torch.cat((torch.ones_like(h[..., :1]), h.cumprod(-1)), -1)
        h = h * params.view(*params.shape[:-1], 1, 2, self._order)
        # Derivative coefficients: d/ds of a_k * s**k is k * a_k * s**(k-1), so scale term k by k = 1..order
        h = h * torch.arange(1, self._order + 1).view(*([1] * params.ndim), -1).to(x.device)
h = h.sum(-1)
if squeeze:
h = h.squeeze(2)
return h
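    # Illustrative note (for self._order == 3): along each of the x and y axes, polynomial
    # evaluates p(s) = a1 * s + a2 * s**2 + a3 * s**3 and dpolynomial its derivative
    # p'(s) = a1 + 2 * a2 * s + 3 * a3 * s**2, which the forward pass multiplies by the
    # progress rate ds/dt (chain rule) to turn progress along the path into velocities.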
def forward(
self,
z_samples: torch.Tensor,
mask_z: torch.Tensor,
x: torch.Tensor,
mask_x: torch.Tensor,
encoded_absolute: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
offset: torch.Tensor,
) -> torch.Tensor:
"""Forward function that decodes input tensors into an output tensor of size
        (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim).
        It first predicts polynomial path coefficients and a progress rate along the path, which
        are combined to produce the output state sequence: (x, y, theta, vx, vy), (x, y, vx, vy), or (x, y).
Args:
            z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
mask_z: (batch_size, num_agents) tensor of bool mask
x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            offset: (batch_size, num_agents, state_dim) tensor of agent offsets (accepted for interface compatibility, unused here)
        Returns:
            (batch_size, num_agents, (n_samples), num_steps_future, dynamic_state_dim) output tensor. The sample
            dimension is absent if z_samples has no sample dimension (3D input).
"""
squeeze_output_sample_dim = z_samples.ndim == 3
batch_size = z_samples.shape[0]
h = self._model(
z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
)
if squeeze_output_sample_dim:
h = rearrange(
h, "batch agents sequence features -> batch agents (sequence features)"
)
else:
h = rearrange(
h,
"(batch samples) agents sequence features -> batch agents samples (sequence features)",
batch=batch_size,
)
h = self._output_layer(h)
output = torch.zeros(
(
*h.shape[:-1],
self._model.num_steps_future,
self._model.dynamic_state_dim,
),
device=h.device,
)
        # Split the prediction into polynomial coefficients and a non-negative progress rate per future step
        params = h[..., : 2 * self._order]
        dldt = torch.relu(h[..., 2 * self._order :])
        # Accumulate the progress rate over the time dimension (last axis) to get the distance traveled along the path
        distance = dldt.cumsum(-1) * self._model.dt
output[..., :2] = self.polynomial(distance, params)
        if self._model.dynamic_state_dim == 5:
            # Chain rule: velocity = p'(distance) * d(distance)/dt; the heading follows the velocity direction
            output[..., 3:5] = dldt.unsqueeze(-1) * self.dpolynomial(distance, params)
            output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
        elif self._model.dynamic_state_dim == 4:
            output[..., 2:4] = dldt.unsqueeze(-1) * self.dpolynomial(distance, params)
return output
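if __name__ == "__main__":
    # Minimal smoke-test sketch with assumed toy dimensions: the stand-in backbone below is
    # hypothetical and only exposes the attributes the wrapper reads (it is not CVAEParams).
    # It checks the output shape of the polynomial evaluation used by CVAEParametrizedDecoder.
    class _ToyBackbone(nn.Module):
        hidden_dim = 16
        num_steps_future = 6
        dynamic_state_dim = 4
        dt = 0.1

    toy_decoder = CVAEParametrizedDecoder(_ToyBackbone())
    arc_length = torch.linspace(0.0, 1.0, 6).view(1, 1, 6)  # (batch, agents, num_steps_future)
    coefficients = torch.randn(1, 1, 2 * toy_decoder._order)  # (batch, agents, 2 * order)
    positions = toy_decoder.polynomial(arc_length, coefficients)
    print(positions.shape)  # expected: torch.Size([1, 1, 6, 2])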