Spaces:
Running
Running
from __future__ import annotations | |
from abc import ABC, abstractmethod | |
from typing import Optional, Tuple | |
from numpy import isin | |
import torch | |
from risk_biased.mpc_planner.planner_cost import TrackingCost | |
from risk_biased.utils.cost import BaseCostTorch | |
from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator | |
def get_rotation_matrix(angle, device): | |
c = torch.cos(angle) | |
s = torch.sin(angle) | |
rot_matrix = torch.stack( | |
(torch.stack((c, s), -1), torch.stack((-s, c), -1)), -1 | |
).to(device) | |
return rot_matrix | |
class AbstractState(ABC): | |
""" | |
State representation using an underlying tensor. Position, Velocity, and Angle can be accessed. | |
""" | |
def position(self) -> torch.Tensor: | |
"""Extract position information from the state tensor | |
Returns: | |
position_tensor of size (..., 2) | |
""" | |
def velocity(self) -> torch.Tensor: | |
"""Extract velocity information from the state tensor | |
Returns: | |
velocity_tensor of size (..., 2) | |
""" | |
def angle(self) -> torch.Tensor: | |
"""Extract velocity information from the state tensor | |
Returns: | |
velocity_tensor of size (..., 1) | |
""" | |
def get_states(self, dim: int) -> torch.Tensor: | |
"""Return the underlying states tensor with dim 2, 4 or 5 ([x, y], [x, y, vx, vy], or [x, y, angle, vx, vy]).""" | |
def rotate(self, angle: float, in_place: bool) -> AbstractState: | |
"""Rotate the state by the given angle | |
Args: | |
angle: in radiants | |
in_place: wether to change the object itself or return a rotated copy | |
Returns: | |
rotated self or rotated copy of self | |
""" | |
def translate(self, translation: torch.Tensor, in_place: bool) -> AbstractState: | |
"""Translate the state by the given tranlation | |
Args: | |
translation: translation vector in 2 dimensions | |
in_place: wether to change the object itself or return a rotated copy | |
""" | |
# Define overloading operators to behave as a tensor for some operations | |
def __getitem__(self, key) -> AbstractState: | |
""" | |
Use get item on the underlying tensor to get the item at the given key. | |
Allways returns a velocity state so that if the underlying time sequence is reduced to one step, the velocity is still accessible. | |
""" | |
if isinstance(key, int): | |
key = (key, Ellipsis, slice(None, None, None)) | |
elif Ellipsis not in key: | |
key = (*key, Ellipsis, slice(None, None, None)) | |
else: | |
key = (*key, slice(None, None, None)) | |
return to_state( | |
torch.cat( | |
( | |
self.position[key], | |
self.velocity[key], | |
), | |
dim=-1, | |
), | |
self.dt, | |
) | |
def shape(self): | |
return self._states.shape[:-1] | |
def to_state(in_tensor: torch.Tensor, dt: float) -> AbstractState: | |
if in_tensor.shape[-1] == 2: | |
return PositionSequenceState(in_tensor, dt) | |
elif in_tensor.shape[-1] == 4: | |
return PositionVelocityState(in_tensor, dt) | |
else: | |
assert in_tensor.shape[-1] > 4 | |
return PositionAngleVelocityState(in_tensor, dt) | |
class PositionSequenceState(AbstractState): | |
""" | |
State representation with an underlying tensor defining only positions. | |
""" | |
def __init__(self, states: torch.Tensor, dt: float) -> None: | |
super().__init__() | |
assert ( | |
states.shape[-1] == 2 | |
) # Check that the input tensor defines only the position | |
assert ( | |
states.ndim > 1 and states.shape[-2] > 1 | |
) # Check that the input tensor defines a sequence of positions (otherwise velocity cannot be computed) | |
self.dt = dt | |
self._states = states.clone() | |
def position(self) -> torch.Tensor: | |
return self._states | |
def velocity(self) -> torch.Tensor: | |
vel = (self._states[..., 1:, :] - self._states[..., :-1, :]) / self.dt | |
vel = torch.cat((vel[..., 0:1, :], vel), dim=-2) | |
return vel.clone() | |
def angle(self) -> torch.Tensor: | |
vel = self.velocity | |
angle = torch.arctan2(vel[..., 1:2], vel[..., 0:1]) | |
return angle | |
def get_states(self, dim: int = 2) -> torch.Tensor: | |
if dim == 2: | |
return self._states.clone() | |
elif dim == 4: | |
return torch.cat((self._states.clone(), self.velocity), dim=-1) | |
elif dim == 5: | |
return torch.cat((self._states.clone(), self.angle, self.velocity), dim=-1) | |
else: | |
raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}") | |
def rotate(self, angle: float, in_place: bool = False) -> PositionSequenceState: | |
"""Rotate the state by the given angle in radiants""" | |
rot_matrix = get_rotation_matrix(angle, self._states.device) | |
if in_place: | |
self._states = (rot_matrix @ self._states.unsqueeze(-1)).squeeze(-1) | |
return self | |
else: | |
return to_state( | |
(rot_matrix @ self._states.unsqueeze(-1).clone()).squeeze(-1), self.dt | |
) | |
def translate( | |
self, translation: torch.Tensor, in_place: bool = False | |
) -> PositionSequenceState: | |
"""Translate the state by the given tranlation""" | |
if in_place: | |
self._states[..., :2] += translation.expand_as(self._states[..., :2]) | |
return self | |
else: | |
return to_state( | |
self._states[..., :2].clone() | |
+ translation.expand_as(self._states[..., :2]), | |
self.dt, | |
) | |
class PositionVelocityState(AbstractState): | |
""" | |
State representation with an underlying tensor defining position and velocity. | |
""" | |
def __init__(self, states: torch.Tensor, dt) -> None: | |
super().__init__() | |
assert states.shape[-1] == 4 | |
self._states = states.clone() | |
self.dt = dt | |
def position(self) -> torch.Tensor: | |
return self._states[..., :2] | |
def velocity(self) -> torch.Tensor: | |
return self._states[..., 2:4] | |
def angle(self) -> torch.Tensor: | |
vel = self.velocity | |
angle = torch.arctan2(vel[..., 1:2], vel[..., 0:1]) | |
return angle | |
def get_states(self, dim: int = 4) -> torch.Tensor: | |
if dim == 2: | |
return self._states[..., :2].clone() | |
elif dim == 4: | |
return self._states.clone() | |
elif dim == 5: | |
return torch.cat( | |
( | |
self._states[..., :2].clone(), | |
self.angle, | |
self._states[..., 2:].clone(), | |
), | |
dim=-1, | |
) | |
else: | |
raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}") | |
def rotate( | |
self, angle: torch.Tensor, in_place: bool = False | |
) -> PositionVelocityState: | |
"""Rotate the state by the given angle in radiants""" | |
rot_matrix = get_rotation_matrix(angle, self._states.device) | |
rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1) | |
rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1) | |
if in_place: | |
self._states = torch.cat((rotated_pos, rotated_vel), dim=-1) | |
return self | |
else: | |
return to_state(torch.cat((rotated_pos, rotated_vel), dim=-1), self.dt) | |
def translate( | |
self, translation: torch.Tensor, in_place: bool = False | |
) -> PositionVelocityState: | |
"""Translate the state by the given tranlation""" | |
if in_place: | |
self._states[..., :2] += translation.expand_as(self._states[..., :2]) | |
return self | |
else: | |
return to_state( | |
torch.cat( | |
( | |
self._states[..., :2].clone() | |
+ translation.expand_as(self._states[..., :2]), | |
self._states[..., 2:].clone(), | |
), | |
dim=-1, | |
), | |
self.dt, | |
) | |
class PositionAngleVelocityState(AbstractState): | |
""" | |
State representation with an underlying tensor representing position angle and velocity. | |
""" | |
def __init__(self, states: torch.Tensor, dt: float) -> None: | |
super().__init__() | |
assert states.shape[-1] == 5 | |
self._states = states.clone() | |
self.dt = dt | |
def position(self) -> torch.Tensor: | |
return self._states[..., :2].clone() | |
def velocity(self) -> torch.Tensor: | |
return self._states[..., 3:5].clone() | |
def angle(self) -> torch.Tensor: | |
return self._states[..., 2:3].clone() | |
def get_states(self, dim: int = 5) -> torch.Tensor: | |
if dim == 2: | |
return self._states[..., :2].clone() | |
elif dim == 4: | |
return torch.cat( | |
(self._states[..., :2].clone(), self._states[..., 3:].clone()), dim=-1 | |
) | |
elif dim == 5: | |
return self._states.clone() | |
else: | |
raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}") | |
def rotate( | |
self, angle: float, in_place: bool = False | |
) -> PositionAngleVelocityState: | |
"""Rotate the state by the given angle in radiants""" | |
rot_matrix = get_rotation_matrix(angle, self._states.device) | |
rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1) | |
rotated_angle = self.angle + angle | |
rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1) | |
if in_place: | |
self._states = torch.cat(rotated_pos, rotated_angle, rotated_vel, -1) | |
return self | |
else: | |
return to_state( | |
torch.cat(rotated_pos, rotated_angle, rotated_vel, -1), self.dt | |
) | |
def translate( | |
self, translation: torch.Tensor, in_place: bool = False | |
) -> PositionAngleVelocityState: | |
"""Translate the state by the given tranlation""" | |
if in_place: | |
self._states[..., :2] += translation.expand_as(self._states[..., :2]) | |
return self | |
else: | |
return to_state( | |
torch.cat( | |
( | |
self._states[..., :2] | |
+ translation.expand_as(self._states[..., :2]), | |
self._states[..., 2:], | |
), | |
dim=-1, | |
), | |
self.dt, | |
) | |
def get_interaction_cost( | |
ego_state_future: AbstractState, | |
ado_state_future_samples: AbstractState, | |
interaction_cost_function: BaseCostTorch, | |
) -> torch.Tensor: | |
"""Computes interaction cost samples from predicted ado future trajectories and a batch of ego | |
future trajectories | |
Args: | |
ego_state_future: ((num_control_samples), num_agents, num_steps_future) ego state future | |
future trajectory | |
ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future) | |
predicted ado state trajectory samples | |
interaction_cost_function: interaction cost function between ego and (stochastic) ado | |
dt: time differential between two discrete timesteps in seconds | |
Returns: | |
(num_control_samples, num_agents, num_prediction_samples) interaction cost tensor | |
""" | |
if len(ego_state_future.shape) == 2: | |
x_ego = ego_state_future.position.unsqueeze(0) | |
v_ego = ego_state_future.velocity.unsqueeze(0) | |
else: | |
x_ego = ego_state_future.position | |
v_ego = ego_state_future.velocity | |
num_control_samples = ego_state_future.shape[0] | |
ado_position_future_samples = ado_state_future_samples.position.unsqueeze(0).expand( | |
num_control_samples, -1, -1, -1, -1 | |
) | |
v_samples = ado_state_future_samples.velocity.unsqueeze(0).expand( | |
num_control_samples, -1, -1, -1, -1 | |
) | |
interaction_cost, _ = interaction_cost_function( | |
x1=x_ego.unsqueeze(1), | |
x2=ado_position_future_samples, | |
v1=v_ego.unsqueeze(1), | |
v2=v_samples, | |
) | |
return interaction_cost.permute(0, 2, 1) | |
def evaluate_risk( | |
risk_level: float, | |
cost: torch.Tensor, | |
weights: torch.Tensor, | |
risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None, | |
) -> torch.Tensor: | |
"""Returns a risk tensor given costs and optionally a risk level | |
Args: | |
risk_level (optional): a risk-level float. If 0.0, risk-neutral expectation will be | |
returned. Defaults to 0.0. | |
cost: (num_control_samples, num_agents, num_prediction_samples) cost tensor | |
weights: (num_control_samples, num_agents, num_prediction_samples) probability weight of the cost tensor | |
risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None. | |
Returns: | |
(num_control_samples, num_agents) risk tensor | |
""" | |
num_control_samples, num_agents, _ = cost.shape | |
if risk_level == 0.0: | |
risk = cost.mean(dim=-1) | |
else: | |
assert risk_estimator is not None, "no risk estimator is specified" | |
risk = risk_estimator( | |
risk_level * torch.ones(num_control_samples, num_agents), | |
cost, | |
weights=weights, | |
) | |
return risk | |
def evaluate_control_sequence( | |
control_sequence: torch.Tensor, | |
dynamics_model, | |
ego_state_history: AbstractState, | |
ego_state_target_trajectory: AbstractState, | |
ado_state_future_samples: AbstractState, | |
sample_weights: torch.Tensor, | |
interaction_cost_function: BaseCostTorch, | |
tracking_cost_function: TrackingCost, | |
risk_level: float = 0.0, | |
risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None, | |
) -> Tuple[float, float]: | |
"""Returns the risk and tracking cost evaluation of the given control sequence | |
Args: | |
control_sequence: (num_steps_future, control_dim) tensor of control sequence | |
dynamics_model: dynamics model for control | |
ego_state_target_trajectory: (num_steps_future) tensor of ego target | |
state trajectory | |
ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future) | |
of predicted ado trajectory samples states | |
sample_weights: (num_prediction_samples, num_agents) tensor of probability weights of the samples | |
intraction_cost_function: interaction cost function between ego and (stochastic) ado | |
tracking_cost_function: deterministic tracking cost that does not involve ado | |
risk_level: risk_level (optional): a risk-level float. If 0.0, risk-neutral expectation | |
is used. Defaults to 0.0. | |
risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None. | |
Returns: | |
tuple of (interaction risk, tracking_cost) | |
""" | |
ego_state_current = ego_state_history[..., -1] | |
ego_state_future = dynamics_model.simulate(ego_state_current, control_sequence) | |
# state starts with x, y, angle, vx, vy | |
tracking_cost = tracking_cost_function( | |
ego_state_future.position, | |
ego_state_target_trajectory.position, | |
ego_state_target_trajectory.velocity, | |
) | |
interaction_cost = get_interaction_cost( | |
ego_state_future, | |
ado_state_future_samples, | |
interaction_cost_function, | |
) | |
interaction_risk = evaluate_risk( | |
risk_level, | |
interaction_cost, | |
sample_weights.permute(1, 0).unsqueeze(0).expand_as(interaction_cost), | |
risk_estimator, | |
) | |
# TODO: averaging over agents but we might want to reduce a different way | |
return (interaction_risk.mean().item(), tracking_cost.mean().item()) | |