from typing import Optional

import torch
import torch.nn as nn

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.nn_blocks import (
    MCG,
    MAB,
    MHB,
    SequenceEncoderLSTM,
    SequenceEncoderMLP,
    SequenceEncoderMaskedLSTM,
)
from risk_biased.models.latent_distributions import AbstractLatentDistribution


class BaseEncoderNN(nn.Module):
    """Base encoder neural network that defines the common functionality of encoders.

    It should not be used directly but rather extended to define specific encoders.

    Args:
        params: dataclass defining the necessary parameters
        latent_dim: dimension of the latent space
        num_steps: length of the input sequence
    """

    def __init__(
        self,
        params: CVAEParams,
        latent_dim: int,
        num_steps: int,
    ) -> None:
        super().__init__()
        self.is_mlp_residual = params.is_mlp_residual
        self.num_hidden_layers = params.num_hidden_layers
        self.num_steps = params.num_steps
        self.num_steps_future = params.num_steps_future
        self.sequence_encoder_type = params.sequence_encoder_type
        self.state_dim = params.state_dim
        self.latent_dim = latent_dim
        self.hidden_dim = params.hidden_dim

        if params.sequence_encoder_type == "MLP":
            self._agent_encoder = SequenceEncoderMLP(
                params.state_dim,
                params.hidden_dim,
                params.num_hidden_layers,
                num_steps,
                params.is_mlp_residual,
            )
        elif params.sequence_encoder_type == "LSTM":
            self._agent_encoder = SequenceEncoderLSTM(
                params.state_dim, params.hidden_dim
            )
        elif params.sequence_encoder_type == "maskedLSTM":
            self._agent_encoder = SequenceEncoderMaskedLSTM(
                params.state_dim, params.hidden_dim
            )

        if params.interaction_type == "Attention" or params.interaction_type == "MAB":
            self._interaction = MAB(
                params.hidden_dim, params.num_attention_heads, params.num_blocks
            )
        elif (
            params.interaction_type == "ContextGating"
            or params.interaction_type == "MCG"
        ):
            self._interaction = MCG(
                params.hidden_dim,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
            self._interaction = MHB(
                params.hidden_dim,
                params.num_attention_heads,
                params.mcg_dim_expansion,
                params.mcg_num_layers,
                params.num_blocks,
                params.is_mlp_residual,
            )
        else:
            # No interaction module: pass agent features through unchanged.
            self._interaction = lambda x, *args, **kwargs: x

        self._output_layer = nn.Linear(params.hidden_dim, self.latent_dim)

    def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
        raise NotImplementedError

    def forward(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        encoded_absolute: torch.Tensor,
        encoded_map: torch.Tensor,
        mask_map: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        mask_y: Optional[torch.Tensor] = None,
        x_ego: Optional[torch.Tensor] = None,
        y_ego: Optional[torch.Tensor] = None,
        offset: Optional[torch.Tensor] = None,
        risk_level: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function that encodes input tensors into an output tensor of
        dimension latent_dim.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            encoded_absolute: (batch_size, num_agents, feature_size) tensor of the
                encoded absolute agent positions
            encoded_map: (batch_size, num_objects, map_feature_dim) tensor of
                encoded map objects
            mask_map: (batch_size, num_objects) tensor of bool mask
            y (optional): (batch_size, num_agents, num_steps_future, state_dim)
                tensor of future trajectory. Defaults to None.
            mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of
                bool mask. Defaults to None.
            x_ego (optional): (batch_size, 1, num_steps, state_dim) ego history.
                Defaults to None.
            y_ego (optional): (batch_size, 1, num_steps_future, state_dim) ego
                future. Defaults to None.
            offset (optional): (batch_size, num_agents, state_dim) offset position
                from ego. Defaults to None.
            risk_level (optional): (batch_size, num_agents) tensor of risk levels
                desired for future trajectories. Defaults to None.

        Returns:
            (batch_size, num_agents, latent_dim) output tensor
        """
        h_agents = self.encode_agents(
            x=x,
            mask_x=mask_x,
            y=y,
            mask_y=mask_y,
            x_ego=x_ego,
            y_ego=y_ego,
            offset=offset,
            risk_level=risk_level,
        )
        # An agent takes part in the interaction if it has at least one observed step.
        mask_agent = mask_x.any(-1)
        h_agents = self._interaction(
            h_agents, mask_agent, encoded_absolute, encoded_map, mask_map
        )
        return self._output_layer(h_agents)

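# Illustrative sketch (not part of the model API): the per-agent mask passed to
# the interaction layer in BaseEncoderNN.forward keeps an agent if it has at
# least one observed timestep in its history mask.
def _mask_agent_example() -> torch.Tensor:
    # (batch_size=1, num_agents=2, num_steps=2); agent 1 is fully padded.
    mask_x = torch.tensor([[[True, True], [False, False]]])
    return mask_x.any(-1)  # tensor([[True, False]])
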
class BiasedEncoderNN(BaseEncoderNN):
    """Biased encoder neural network that encodes past info and auxiliary input
    into a biased distribution over the latent space.

    Args:
        params: dataclass defining the necessary parameters
        latent_dim: dimension of the latent space
        num_steps: length of the input sequence
    """

    def __init__(
        self,
        params: CVAEParams,
        latent_dim: int,
        num_steps: int,
    ) -> None:
        super().__init__(params, latent_dim, num_steps)
        self.condition_on_ego_future = params.condition_on_ego_future
        if params.sequence_encoder_type == "MLP":
            self._ego_encoder = SequenceEncoderMLP(
                params.state_dim,
                params.hidden_dim,
                params.num_hidden_layers,
                # Future steps are counted only when conditioning on the ego future.
                params.num_steps
                + params.num_steps_future * self.condition_on_ego_future,
                params.is_mlp_residual,
            )
        elif params.sequence_encoder_type == "LSTM":
            self._ego_encoder = SequenceEncoderLSTM(params.state_dim, params.hidden_dim)
        elif params.sequence_encoder_type == "maskedLSTM":
            self._ego_encoder = SequenceEncoderMaskedLSTM(
                params.state_dim, params.hidden_dim
            )
        # Input is the agent feature (hidden_dim) concatenated with the risk
        # feature (1) and the encoded ego trajectory (hidden_dim).
        self._auxiliary_encode = nn.Linear(
            params.hidden_dim + 1 + params.hidden_dim, params.hidden_dim
        )

    def biased_parameters(self, recurse: bool = True):
        """Get the parameters to be optimized when training to bias."""
        yield from self.parameters(recurse)

    def encode_agents(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        *,
        x_ego: torch.Tensor,
        y_ego: torch.Tensor,
        offset: torch.Tensor,
        risk_level: torch.Tensor,
        **kwargs,
    ):
        """Encode agent input and auxiliary input into a feature vector.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            x_ego: (batch_size, 1, num_steps, state_dim) ego history
            y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
            offset: (batch_size, num_agents, state_dim) offset position from ego
            risk_level: (batch_size, num_agents) tensor of risk levels desired for
                future trajectories

        Returns:
            (batch_size, num_agents, hidden_dim) output tensor
        """
        if self.condition_on_ego_future:
            ego_tensor = torch.cat([x_ego, y_ego], dim=-2)
        else:
            ego_tensor = x_ego

        # Map risk_level in [0, 1] to the positive multiplicative feature
        # exp(10 * (risk_level - 0.5)).
        risk_feature = ((risk_level - 0.5) * 10).exp().unsqueeze(-1)
        mask_ego = torch.ones(
            ego_tensor.shape[0],
            offset.shape[1],
            ego_tensor.shape[2],
            device=ego_tensor.device,
        )
        batch_size, n_agents, dynamic_state_dim = offset.shape
        state_dim = ego_tensor.shape[-1]
        # Zero-pad the offset so it matches the full state dimension.
        extended_offset = torch.cat(
            (
                offset,
                torch.zeros(
                    batch_size,
                    n_agents,
                    state_dim - dynamic_state_dim,
                    device=offset.device,
                ),
            ),
            dim=-1,
        ).unsqueeze(-2)
        if extended_offset.shape[1] > 1:
            # Express the ego trajectory in each agent's reference frame.
            ego_encoded = self._ego_encoder(
                ego_tensor + extended_offset[:, :1] - extended_offset, mask_ego
            )
        else:
            ego_encoded = self._ego_encoder(ego_tensor - extended_offset, mask_ego)
        auxiliary_input = torch.cat((risk_feature, ego_encoded), -1)

        h_agents = self._agent_encoder(x, mask_x)
        h_agents = torch.cat([h_agents, auxiliary_input], dim=-1)
        h_agents = self._auxiliary_encode(h_agents)

        return h_agents

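# Illustrative sketch (not part of the model API): the risk scaling used in
# BiasedEncoderNN.encode_agents maps a risk level in [0, 1] to a positive
# multiplicative feature exp(10 * (risk_level - 0.5)), spanning roughly
# [exp(-5), exp(5)] ~ [0.0067, 148.4].
def _risk_feature_example() -> torch.Tensor:
    risk_level = torch.tensor([[0.0, 0.5, 1.0]])  # (batch_size=1, num_agents=3)
    # Same transform as in encode_agents; output shape is (1, 3, 1).
    return ((risk_level - 0.5) * 10).exp().unsqueeze(-1)
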
class InferenceEncoderNN(BaseEncoderNN):
    """Inference encoder neural network that encodes past info into the inference
    distribution over the latent space.

    Args:
        params: dataclass defining the necessary parameters
        latent_dim: dimension of the latent space
        num_steps: length of the input sequence
    """

    def biased_parameters(self, recurse: bool = True):
        """The inference encoder is not optimized when training to bias."""
        yield from []

    def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
        h_agents = self._agent_encoder(x, mask_x)
        return h_agents


class FutureEncoderNN(BaseEncoderNN):
    """Future encoder neural network that encodes past and future info into the
    future-conditioned distribution over the latent space. The future is not
    available at test time, so this encoder is only used during training.

    Args:
        params: dataclass defining the necessary parameters
        latent_dim: dimension of the latent space
        num_steps: length of the input sequence
    """

    def biased_parameters(self, recurse: bool = True):
        """The future encoder is not optimized when training to bias."""
        yield from []

    def encode_agents(
        self,
        x: torch.Tensor,
        mask_x: torch.Tensor,
        *,
        y: torch.Tensor,
        mask_y: torch.Tensor,
        **kwargs,
    ):
        """Encode agent history and future into a feature vector.

        Args:
            x: (batch_size, num_agents, num_steps, state_dim) tensor of trajectory
                history
            mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
            y: (batch_size, num_agents, num_steps_future, state_dim) future
                trajectory
            mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask

        Returns:
            (batch_size, num_agents, hidden_dim) output tensor
        """
        # Concatenate history and future along the time dimension before encoding.
        mask_traj = torch.cat([mask_x, mask_y], dim=-1)
        h_agents = self._agent_encoder(torch.cat([x, y], dim=-2), mask_traj)
        return h_agents

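# Illustrative sketch (not part of the model API): FutureEncoderNN stacks
# history and future along the time axis, so its sequence encoder sees
# trajectories of length num_steps + num_steps_future.
def _future_concat_example() -> torch.Tensor:
    batch_size, num_agents, state_dim = 2, 3, 5
    x = torch.zeros(batch_size, num_agents, 4, state_dim)  # num_steps = 4
    y = torch.zeros(batch_size, num_agents, 6, state_dim)  # num_steps_future = 6
    return torch.cat([x, y], dim=-2)  # shape (2, 3, 10, 5)
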
""" def __init__( self, model: BaseEncoderNN, latent_distribution_creator, ) -> None: super().__init__() self._model = model self.latent_dim = model.latent_dim self._latent_distribution_creator = latent_distribution_creator def biased_parameters(self, recurse: bool = True): yield from self._model.biased_parameters(recurse) def forward( self, x: torch.Tensor, mask_x: torch.Tensor, encoded_absolute: torch.Tensor, encoded_map: torch.Tensor, mask_map: torch.Tensor, y: Optional[torch.Tensor] = None, mask_y: Optional[torch.Tensor] = None, x_ego: Optional[torch.Tensor] = None, y_ego: Optional[torch.Tensor] = None, offset: Optional[torch.Tensor] = None, risk_level: Optional[torch.Tensor] = None, ) -> AbstractLatentDistribution: """Forward function that encodes input tensors into an output tensor of dimension latent_dim. Args: x: (batch_size, num_agents, num_steps, state_dim) tensor of history mask_x: (batch_size, num_agents, num_steps) tensor of bool mask encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects mask_map: (batch_size, num_objects) tensor of bool mask y (optional): (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory. mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of bool mask. Defaults to None. x_ego (optional): (batch_size, 1, num_steps, state_dim) ego history y_ego (optional): (batch_size, 1, num_steps_future, state_dim) ego future offset (optional): (batch_size, num_agents, state_dim) offset position from ego. risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future trajectories. Defaults to None. Returns: Latent distribution representing the posterior over the latent variables given the input observations. """ latent_output = self._model( x=x, mask_x=mask_x, encoded_absolute=encoded_absolute, encoded_map=encoded_map, mask_map=mask_map, y=y, mask_y=mask_y, x_ego=x_ego, y_ego=y_ego, offset=offset, risk_level=risk_level, ) latent_distribution = self._latent_distribution_creator(latent_output) return latent_distribution