File size: 16,040 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, Tuple
from numpy import isin

import torch

from risk_biased.mpc_planner.planner_cost import TrackingCost
from risk_biased.utils.cost import BaseCostTorch
from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator


def get_rotation_matrix(angle, device):
    """Build the 2D counter-clockwise rotation matrix for ``angle``.

    Args:
        angle: rotation angle in radians, either a Python number or a tensor of
            any shape (...).
        device: device on which the returned matrix is placed.

    Returns:
        rotation matrix tensor of shape (..., 2, 2) equal to [[cos, -sin], [sin, cos]].
    """
    # torch.cos / torch.sin only accept tensors; converting lets callers pass a plain
    # float as the `rotate(angle: float, ...)` signatures advertise. Tensors are
    # passed through unchanged.
    angle = torch.as_tensor(angle)
    c = torch.cos(angle)
    s = torch.sin(angle)
    # Stacking the columns (c, s) and (-s, c) along the last axis yields [[c, -s], [s, c]].
    rot_matrix = torch.stack(
        (torch.stack((c, s), -1), torch.stack((-s, c), -1)), -1
    ).to(device)
    return rot_matrix


class AbstractState(ABC):
    """
    State representation using an underlying tensor. Position, Velocity, and Angle can be accessed.

    Concrete subclasses are expected to set ``self._states`` (the backing tensor, whose
    last dimension holds the per-step features) and ``self.dt`` (the time step that is
    forwarded to ``to_state`` when a sub-state is built in ``__getitem__``).
    """

    @property
    @abstractmethod
    def position(self) -> torch.Tensor:
        """Extract position information from the state tensor

        Returns:
            position_tensor of size (..., 2)
        """

    @property
    @abstractmethod
    def velocity(self) -> torch.Tensor:
        """Extract velocity information from the state tensor

        Returns:
            velocity_tensor of size (..., 2)
        """

    @property
    @abstractmethod
    def angle(self) -> torch.Tensor:
        """Extract angle information from the state tensor

        Returns:
            angle_tensor of size (..., 1)
        """

    @abstractmethod
    def get_states(self, dim: int) -> torch.Tensor:
        """Return the underlying states tensor with dim 2, 4 or 5 ([x, y], [x, y, vx, vy], or [x, y, angle, vx, vy])."""

    @abstractmethod
    def rotate(self, angle: float, in_place: bool) -> AbstractState:
        """Rotate the state by the given angle
        Args:
            angle: in radians
            in_place: whether to change the object itself or return a rotated copy
        Returns:
            rotated self or rotated copy of self
        """

    @abstractmethod
    def translate(self, translation: torch.Tensor, in_place: bool) -> AbstractState:
        """Translate the state by the given translation
        Args:
            translation: translation vector in 2 dimensions
            in_place: whether to change the object itself or return a translated copy
        """

    # Define overloading operators to behave as a tensor for some operations
    def __getitem__(self, key) -> AbstractState:
        """
        Use get item on the underlying tensor to get the item at the given key.

        The key indexes the batch/time dimensions only; the feature dimension is always
        kept whole. Always returns a velocity state so that if the underlying time
        sequence is reduced to one step, the velocity is still accessible.
        """
        # Normalize any non-tuple key (int, slice, tensor, Ellipsis, ...) to a tuple so the
        # feature-dimension slice can be appended uniformly. Previously a plain slice key
        # crashed on `Ellipsis not in key`.
        if not isinstance(key, tuple):
            key = (key,)
        # Identity test: `Ellipsis in key` would call `==` on tensor/array indices.
        if not any(k is Ellipsis for k in key):
            key = (*key, Ellipsis)
        key = (*key, slice(None, None, None))

        return to_state(
            torch.cat(
                (
                    self.position[key],
                    self.velocity[key],
                ),
                dim=-1,
            ),
            self.dt,
        )

    @property
    def shape(self):
        """Batch/time shape of the state: the backing tensor's shape without the feature dimension."""
        return self._states.shape[:-1]


def to_state(in_tensor: torch.Tensor, dt: float) -> AbstractState:
    """Wrap a tensor in the state class that matches its last (feature) dimension.

    Args:
        in_tensor: (..., 2) positions, (..., 4) [x, y, vx, vy], or (..., 5)
            [x, y, angle, vx, vy] states.
        dt: time step stored on the resulting state.

    Returns:
        a PositionSequenceState, PositionVelocityState or PositionAngleVelocityState.
    """
    feature_dim = in_tensor.shape[-1]
    if feature_dim == 2:
        state_cls = PositionSequenceState
    elif feature_dim == 4:
        state_cls = PositionVelocityState
    else:
        # Anything wider than 4 features is handed to the angle-aware representation,
        # whose constructor enforces exactly 5.
        assert feature_dim > 4
        state_cls = PositionAngleVelocityState
    return state_cls(in_tensor, dt)


class PositionSequenceState(AbstractState):
    """
    State representation whose underlying tensor holds positions only.

    Velocity and angle are derived by finite differences along dimension -2, so the
    tensor must contain at least two entries along that dimension.
    """

    def __init__(self, states: torch.Tensor, dt: float) -> None:
        super().__init__()
        # Only [x, y] features are allowed in this representation.
        assert states.shape[-1] == 2
        # At least two positions are needed to difference them into a velocity.
        assert states.ndim > 1 and states.shape[-2] > 1
        self.dt = dt
        self._states = states.clone()

    @property
    def position(self) -> torch.Tensor:
        # Returns the backing tensor itself (no copy).
        return self._states

    @property
    def velocity(self) -> torch.Tensor:
        later = self._states[..., 1:, :]
        earlier = self._states[..., :-1, :]
        rate = (later - earlier) / self.dt
        # Duplicate the first difference so there is one velocity per position.
        first = rate[..., 0:1, :]
        return torch.cat((first, rate), dim=-2).clone()

    @property
    def angle(self) -> torch.Tensor:
        speed = self.velocity
        return torch.arctan2(speed[..., 1:2], speed[..., 0:1])

    def get_states(self, dim: int = 2) -> torch.Tensor:
        """Return [x, y] (dim=2), [x, y, vx, vy] (dim=4) or [x, y, angle, vx, vy] (dim=5)."""
        if dim == 2:
            return self._states.clone()
        if dim == 4:
            return torch.cat((self._states.clone(), self.velocity), dim=-1)
        if dim == 5:
            return torch.cat((self._states.clone(), self.angle, self.velocity), dim=-1)
        raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(self, angle: float, in_place: bool = False) -> PositionSequenceState:
        """Rotate every position by ``angle`` (radians), in place or on a copy."""
        rotation = get_rotation_matrix(angle, self._states.device)
        if not in_place:
            as_columns = self._states.unsqueeze(-1).clone()
            return to_state((rotation @ as_columns).squeeze(-1), self.dt)
        self._states = (rotation @ self._states.unsqueeze(-1)).squeeze(-1)
        return self

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionSequenceState:
        """Shift every position by ``translation``, in place or on a copy."""
        if not in_place:
            shifted = self._states[..., :2].clone() + translation.expand_as(
                self._states[..., :2]
            )
            return to_state(shifted, self.dt)
        self._states[..., :2] += translation.expand_as(self._states[..., :2])
        return self


class PositionVelocityState(AbstractState):
    """
    State representation whose underlying tensor holds [x, y, vx, vy] per entry.
    """

    def __init__(self, states: torch.Tensor, dt) -> None:
        super().__init__()
        assert states.shape[-1] == 4
        self._states = states.clone()
        self.dt = dt

    @property
    def position(self) -> torch.Tensor:
        # View into the backing tensor (not a copy).
        return self._states[..., :2]

    @property
    def velocity(self) -> torch.Tensor:
        # View into the backing tensor (not a copy).
        return self._states[..., 2:4]

    @property
    def angle(self) -> torch.Tensor:
        vel = self.velocity
        vx, vy = vel[..., 0:1], vel[..., 1:2]
        return torch.arctan2(vy, vx)

    def get_states(self, dim: int = 4) -> torch.Tensor:
        """Return [x, y] (dim=2), [x, y, vx, vy] (dim=4) or [x, y, angle, vx, vy] (dim=5)."""
        if dim == 2:
            return self._states[..., :2].clone()
        if dim == 4:
            return self._states.clone()
        if dim == 5:
            xy = self._states[..., :2].clone()
            heading = self.angle
            vxy = self._states[..., 2:].clone()
            return torch.cat((xy, heading, vxy), dim=-1)
        raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(
        self, angle: torch.Tensor, in_place: bool = False
    ) -> PositionVelocityState:
        """Rotate positions and velocities by ``angle`` (radians), in place or on a copy."""
        rotation = get_rotation_matrix(angle, self._states.device)
        new_pos = (rotation @ self.position.unsqueeze(-1)).squeeze(-1)
        new_vel = (rotation @ self.velocity.unsqueeze(-1)).squeeze(-1)
        rotated = torch.cat((new_pos, new_vel), dim=-1)
        if not in_place:
            return to_state(rotated, self.dt)
        self._states = rotated
        return self

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionVelocityState:
        """Shift positions by ``translation`` (velocities unchanged), in place or on a copy."""
        if not in_place:
            moved_pos = self._states[..., :2].clone() + translation.expand_as(
                self._states[..., :2]
            )
            kept_vel = self._states[..., 2:].clone()
            return to_state(torch.cat((moved_pos, kept_vel), dim=-1), self.dt)
        self._states[..., :2] += translation.expand_as(self._states[..., :2])
        return self


class PositionAngleVelocityState(AbstractState):
    """
    State representation with an underlying tensor holding [x, y, angle, vx, vy] per entry.

    The position/velocity/angle accessors return copies, so modifying their results
    does not alter the state.
    """

    def __init__(self, states: torch.Tensor, dt: float) -> None:
        super().__init__()
        assert states.shape[-1] == 5
        self._states = states.clone()
        self.dt = dt

    @property
    def position(self) -> torch.Tensor:
        """Copy of the [x, y] features, shape (..., 2)."""
        return self._states[..., :2].clone()

    @property
    def velocity(self) -> torch.Tensor:
        """Copy of the [vx, vy] features, shape (..., 2)."""
        return self._states[..., 3:5].clone()

    @property
    def angle(self) -> torch.Tensor:
        """Copy of the stored angle feature, shape (..., 1)."""
        return self._states[..., 2:3].clone()

    def get_states(self, dim: int = 5) -> torch.Tensor:
        """Return a copy of the states as [x, y] (dim=2), [x, y, vx, vy] (dim=4) or
        [x, y, angle, vx, vy] (dim=5).

        Raises:
            RuntimeError: if ``dim`` is not 2, 4 or 5.
        """
        if dim == 2:
            return self._states[..., :2].clone()
        elif dim == 4:
            return torch.cat(
                (self._states[..., :2].clone(), self._states[..., 3:].clone()), dim=-1
            )
        elif dim == 5:
            return self._states.clone()
        else:
            raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")

    def rotate(
        self, angle: float, in_place: bool = False
    ) -> PositionAngleVelocityState:
        """Rotate the state by the given angle in radians.

        Positions and velocities are rotated by the rotation matrix, and ``angle`` is
        added to the stored angle feature.

        Args:
            angle: rotation in radians, a Python number or a tensor broadcastable to the
                state's batch shape (``self.shape``).
            in_place: whether to modify this object or return a rotated copy.

        Returns:
            rotated self or a rotated copy of self.
        """
        # Use a tensor on the state's device/dtype: Python floats are accepted, and the
        # angle can be given a trailing feature axis below.
        angle = torch.as_tensor(
            angle, dtype=self._states.dtype, device=self._states.device
        )
        rot_matrix = get_rotation_matrix(angle, self._states.device)
        rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1)
        # self.angle is (..., 1). Add a trailing axis so a per-entry angle of shape (...)
        # lines up with it; a scalar angle broadcasts either way.
        rotated_angle = self.angle + angle.unsqueeze(-1)
        rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1)
        # torch.cat takes a sequence of tensors. The previous call passed them as
        # separate positional arguments, which always raised TypeError.
        rotated_states = torch.cat((rotated_pos, rotated_angle, rotated_vel), dim=-1)
        if in_place:
            self._states = rotated_states
            return self
        else:
            return to_state(rotated_states, self.dt)

    def translate(
        self, translation: torch.Tensor, in_place: bool = False
    ) -> PositionAngleVelocityState:
        """Translate the positions by the given translation (angle and velocity unchanged)."""
        if in_place:
            self._states[..., :2] += translation.expand_as(self._states[..., :2])
            return self
        else:
            return to_state(
                torch.cat(
                    (
                        self._states[..., :2]
                        + translation.expand_as(self._states[..., :2]),
                        self._states[..., 2:],
                    ),
                    dim=-1,
                ),
                self.dt,
            )


def get_interaction_cost(
    ego_state_future: AbstractState,
    ado_state_future_samples: AbstractState,
    interaction_cost_function: BaseCostTorch,
) -> torch.Tensor:
    """Computes interaction cost samples from predicted ado future trajectories and a batch of ego
    future trajectories

    Args:
        ego_state_future: ((num_control_samples), num_agents, num_steps_future) ego future
            trajectory; the leading control-sample dimension is optional.
        ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
            predicted ado state trajectory samples
        interaction_cost_function: interaction cost function between ego and (stochastic) ado.
            It is called with keyword arguments x1, x2, v1, v2 and must return a 2-tuple whose
            first element is the cost.

    Returns:
        (num_control_samples, num_agents, num_prediction_samples) interaction cost tensor;
        num_control_samples is 1 when ego_state_future has no control-sample dimension.
    """
    x_ego = ego_state_future.position
    v_ego = ego_state_future.velocity
    if len(ego_state_future.shape) == 2:
        # Unbatched ego trajectory: add a singleton control-sample dimension.
        x_ego = x_ego.unsqueeze(0)
        v_ego = v_ego.unsqueeze(0)

    # Read the count from the (possibly unsqueezed) ego tensor. For an unbatched ego,
    # ego_state_future.shape[0] is num_agents, not the number of control samples.
    num_control_samples = x_ego.shape[0]
    ado_position_future_samples = ado_state_future_samples.position.unsqueeze(0).expand(
        num_control_samples, -1, -1, -1, -1
    )

    v_samples = ado_state_future_samples.velocity.unsqueeze(0).expand(
        num_control_samples, -1, -1, -1, -1
    )

    # Ego gets a singleton prediction-sample axis (dim 1) to broadcast against the ado samples.
    interaction_cost, _ = interaction_cost_function(
        x1=x_ego.unsqueeze(1),
        x2=ado_position_future_samples,
        v1=v_ego.unsqueeze(1),
        v2=v_samples,
    )
    # Assumes the cost comes back as (num_control_samples, num_prediction_samples,
    # num_agents); swap the last two axes to match the documented output layout.
    return interaction_cost.permute(0, 2, 1)


def evaluate_risk(
    risk_level: float,
    cost: torch.Tensor,
    weights: torch.Tensor,
    risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
) -> torch.Tensor:
    """Returns a risk tensor given costs and a risk level

    Args:
        risk_level: a risk-level float. If 0.0, the mean cost over prediction samples is
          returned.
        cost: (num_control_samples, num_agents, num_prediction_samples) cost tensor
        weights: (num_control_samples, num_agents, num_prediction_samples) probability weight of the cost tensor
        risk_estimator (optional): a Monte Carlo risk estimator, required when risk_level is
          not 0.0. Defaults to None.

    Returns:
        (num_control_samples, num_agents) risk tensor

    Raises:
        AssertionError: if risk_level is not 0.0 and no risk_estimator is given.
    """
    num_control_samples, num_agents, _ = cost.shape

    if risk_level == 0.0:
        # NOTE(review): this plain mean ignores `weights`. If the samples are not equally
        # likely, a weighted average may be intended; confirm with callers.
        risk = cost.mean(dim=-1)
    else:
        assert risk_estimator is not None, "no risk estimator is specified"
        # Allocate the per-(control sample, agent) risk levels on the cost's device;
        # a bare torch.ones(...) always lands on the default (CPU) device.
        risk = risk_estimator(
            risk_level
            * torch.ones(num_control_samples, num_agents, device=cost.device),
            cost,
            weights=weights,
        )
    return risk


def evaluate_control_sequence(
    control_sequence: torch.Tensor,
    dynamics_model,
    ego_state_history: AbstractState,
    ego_state_target_trajectory: AbstractState,
    ado_state_future_samples: AbstractState,
    sample_weights: torch.Tensor,
    interaction_cost_function: BaseCostTorch,
    tracking_cost_function: TrackingCost,
    risk_level: float = 0.0,
    risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
) -> Tuple[float, float]:
    """Roll out a control sequence and score it for interaction risk and tracking cost.

    Args:
        control_sequence: (num_steps_future, control_dim) tensor of controls to apply
        dynamics_model: object whose ``simulate(current_state, controls)`` returns the
          resulting future ego state
        ego_state_history: past ego states; its last entry (``[..., -1]``) is used as the
          current ego state
        ego_state_target_trajectory: (num_steps_future) ego target state trajectory
        ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
          predicted ado trajectory sample states
        sample_weights: (num_prediction_samples, num_agents) probability weights of the samples
        interaction_cost_function: interaction cost function between ego and (stochastic) ado
        tracking_cost_function: deterministic tracking cost that does not involve ado
        risk_level: risk-level float; 0.0 means the risk-neutral mean is used. Defaults to 0.0.
        risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None.

    Returns:
        tuple of (interaction risk, tracking cost), each averaged to a Python float
    """
    current_ego = ego_state_history[..., -1]
    rollout = dynamics_model.simulate(current_ego, control_sequence)

    target = ego_state_target_trajectory
    tracking_cost = tracking_cost_function(
        rollout.position, target.position, target.velocity
    )

    interaction_cost = get_interaction_cost(
        rollout, ado_state_future_samples, interaction_cost_function
    )

    # (num_prediction_samples, num_agents) -> (1, num_agents, num_prediction_samples),
    # then broadcast over control samples to match the interaction cost layout.
    aligned_weights = sample_weights.permute(1, 0).unsqueeze(0).expand_as(
        interaction_cost
    )
    interaction_risk = evaluate_risk(
        risk_level, interaction_cost, aligned_weights, risk_estimator
    )

    # TODO: averaging over agents but we might want to reduce a different way
    mean_risk = interaction_risk.mean().item()
    mean_tracking = tracking_cost.mean().item()
    return (mean_risk, mean_tracking)