jmercat's picture
Removed history to avoid any unverified information being released
5769ee4
raw
history blame
3.1 kB
from dataclasses import dataclass, fields

from mmcv import Config
@dataclass
class CVAEParams:
    """Hyper-parameters of the CVAE model, grouped in a single record.

    Attributes:
        state_dim: Dimension of the state at each time step.
        map_state_dim: Dimension of the map point features at each position.
        num_steps: Number of time steps in the past trajectory input.
        num_steps_future: Number of time steps in the future trajectory output.
        latent_dim: Dimension of the latent space.
        hidden_dim: Dimension of the hidden layers.
        num_hidden_layers: Number of layers for each model (encoder, decoder).
        is_mlp_residual: Set to True to add a linear transformation of the
            input to the output of the MLP.
        interaction_type: Whether to use MCG, MAB, or MHB to handle
            interactions.
        num_attention_heads: Number of attention heads to use in MHA blocks.
        mcg_dim_expansion: Dimension expansion factor for the MCG global
            interaction space.
        mcg_num_layers: Number of layers for the MLP MCG blocks.
        num_blocks: Number of interaction blocks to use.
        sequence_encoder_type: Type of sequence encoder: maskedLSTM, LSTM,
            or MLP.
        sequence_decoder_type: Type of sequence decoder: maskedLSTM, LSTM,
            or MLP.
        condition_on_ego_future: Whether to condition the biasing with the
            ego future or only the ego past.
        latent_regularization: Weight of the latent regularization loss.

    The fields below had no description when this docstring was written; the
    notes are inferred from their names only and must be confirmed
    (TODO confirm):
        dt: Presumably the duration of one time step.
        dynamic_state_dim: Presumably the dimension of a dynamic state.
        max_size_lane: Presumably a maximum size for a lane in the map input.
        risk_assymetry_factor: Presumably a factor used by the risk
            computation (the spelling is part of the public field name).
        num_vq: Presumably a count related to vector quantization.
        latent_distribution: Presumably the name of the latent distribution.
    """

    dt: float
    state_dim: int
    dynamic_state_dim: int
    map_state_dim: int
    max_size_lane: int
    num_steps: int
    num_steps_future: int
    latent_dim: int
    hidden_dim: int
    num_hidden_layers: int
    is_mlp_residual: bool
    # NOTE(review): the description lists names (MCG, MAB, MHB) while the
    # annotation is int -- confirm the actual encoding with the callers.
    interaction_type: int
    num_attention_heads: int
    mcg_dim_expansion: int
    mcg_num_layers: int
    num_blocks: int
    sequence_encoder_type: str
    sequence_decoder_type: str
    condition_on_ego_future: bool
    latent_regularization: float
    risk_assymetry_factor: float
    num_vq: int
    latent_distribution: str

    @classmethod
    def from_config(cls, cfg: Config) -> "CVAEParams":
        """Build the parameters by reading one attribute of `cfg` per field.

        Attributes of `cfg` that do not match a field are ignored.

        Args:
            cfg: Configuration object exposing every field of this dataclass
                as an attribute of the same name.

        Returns:
            A new instance of the class the method is called on.

        Raises:
            AttributeError: Propagated from the attribute lookup on `cfg`
                when a field is missing (this is what plain Python objects
                raise; other config types may raise their own error).
        """
        # The attribute list is derived from the dataclass itself, so a field
        # added to the class can never be forgotten here.
        return cls(
            **{
                field.name: getattr(cfg, field.name)
                for field in fields(cls)
                if field.init
            }
        )