File size: 3,810 Bytes
76a55af
 
 
 
 
 
 
 
8edc5d6
76a55af
 
8edc5d6
76a55af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8edc5d6
76a55af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8edc5d6
76a55af
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from typing import Dict, Optional, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray
from torch.distributions import Distribution, constraints

from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
from rl_algo_impls.shared.encoder import EncoderOutDim
from rl_algo_impls.shared.module.utils import mlp


class GridnetDistribution(Distribution):
    """Factored categorical distribution over per-cell actions on a grid.

    Every grid cell picks one value for each action component; component ``i``
    has ``action_vec[i]`` possible values. Internally this is stored as one
    ``MaskedCategorical`` per component, each batched over all cells (the rows
    obtained after reshaping ``logits`` to ``(-1, sum(action_vec))``).

    Args:
        map_size: Number of grid cells per batch element.
        action_vec: Number of choices for each action component.
        logits: Unnormalized logits reshapeable to ``(-1, sum(action_vec))``,
            i.e. per cell the concatenated logits of every component.
            Presumably ``(batch, map_size * sum(action_vec))`` as produced by
            ``GridnetActorHead`` -- confirm for other callers.
        masks: Action masks whose last dimension is ``sum(action_vec)``; assumed
            to share the cell ordering of ``logits`` (TODO confirm). How the
            mask is applied is up to ``MaskedCategorical``.
        validate_args: Forwarded to the child distributions and the base class.
    """

    def __init__(
        self,
        map_size: int,
        action_vec: NDArray[np.int64],
        logits: torch.Tensor,
        masks: torch.Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.map_size = map_size
        self.action_vec = action_vec

        split_sizes = action_vec.tolist()

        # reshape (not view) so non-contiguous mask tensors are accepted too.
        flat_masks = masks.reshape(-1, masks.shape[-1])
        split_masks = torch.split(flat_masks, split_sizes, dim=1)

        # One row per cell holding the concatenated logits of all components;
        # splitting along dim 1 yields one (cells, action_vec[i]) tensor each.
        grid_logits = logits.reshape(-1, int(action_vec.sum()))
        split_logits = torch.split(grid_logits, split_sizes, dim=1)
        self.categoricals = [
            MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
            for lg, m in zip(split_logits, split_masks)
        ]

        batch_shape = logits.size()[:-1] if logits.ndimension() > 1 else torch.Size()
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def _to_grid(self, per_component: torch.Tensor) -> torch.Tensor:
        """Regroup a stacked ``(..., cells, n_components)`` tensor per grid.

        Returns a tensor of shape ``(-1, map_size, len(action_vec))``.
        """
        return per_component.reshape(-1, self.map_size, len(self.action_vec))

    def log_prob(self, action: torch.Tensor) -> torch.Tensor:
        """Return the log-probability of ``action`` summed over cells and components.

        Args:
            action: Integer actions with one entry per component in the last
                dimension, presumably ``(batch, map_size, len(action_vec))`` as
                returned by ``sample``. Need not be contiguous.

        Returns:
            One summed log-probability per leading (batch) entry of the
            ``(-1, map_size, len(action_vec))`` regrouping.
        """
        # reshape (not view): view raises on non-contiguous action tensors.
        # Transposing gives one row of per-cell values for each component.
        per_component = action.reshape(-1, action.shape[-1]).T
        log_probs = torch.stack(
            [c.log_prob(a) for a, c in zip(per_component, self.categoricals)],
            dim=-1,
        )
        return self._to_grid(log_probs).sum(dim=(1, 2))

    def entropy(self) -> torch.Tensor:
        """Return the entropy summed over all cells and components per batch entry."""
        ent = torch.stack([c.entropy() for c in self.categoricals], dim=-1)
        return self._to_grid(ent).sum(dim=(1, 2))

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """Draw actions shaped ``(-1, map_size, len(action_vec))``.

        Any ``sample_shape`` dimensions are folded into the leading dimension.
        """
        s = torch.stack([c.sample(sample_shape) for c in self.categoricals], dim=-1)
        return self._to_grid(s)

    @property
    def mode(self) -> torch.Tensor:
        """Most likely action per cell and component, shaped like ``sample()``."""
        m = torch.stack([c.mode for c in self.categoricals], dim=-1)
        return self._to_grid(m)

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        # Constraints handled by child distributions in dist
        return {}


class GridnetActorHead(Actor):
    """Policy head producing per-cell factored actions for a grid environment.

    An MLP maps the (flat) encoder output to ``map_size * sum(action_vec)``
    logits; ``forward`` wraps them, together with the caller's action masks,
    in a ``GridnetDistribution``.

    Args:
        map_size: Number of grid cells.
        action_vec: Number of choices for each action component of a cell.
        in_dim: Encoder output dimension; this head requires a plain ``int``.
        hidden_sizes: Widths of the hidden MLP layers.
        activation: Activation module class passed to ``mlp``.
        init_layers_orthogonal: Forwarded to ``mlp``.
    """

    def __init__(
        self,
        map_size: int,
        action_vec: NDArray[np.int64],
        in_dim: EncoderOutDim,
        hidden_sizes: Tuple[int, ...] = (32,),
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        self.map_size = map_size
        self.action_vec = action_vec
        # Only a flat (integer) encoder output dimension is supported here.
        assert isinstance(in_dim, int)
        n_logits = map_size * action_vec.sum()
        self._fc = mlp(
            (in_dim,) + hidden_sizes + (n_logits,),
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            # NOTE(review): presumably shrinks the output layer's init so early
            # logits are small (near-uniform policy) -- confirm against mlp.
            final_layer_gain=0.01,
        )

    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Build the action distribution for ``obs`` and evaluate ``actions``.

        Args:
            obs: Encoder output fed through the MLP.
            actions: Optional actions passed through to ``pi_forward``.
            action_masks: Required masks forwarded to ``GridnetDistribution``.

        Returns:
            The result of ``pi_forward`` for the distribution and ``actions``.
        """
        assert (
            action_masks is not None
        ), f"No mask case unhandled in {self.__class__.__name__}"
        dist = GridnetDistribution(
            self.map_size, self.action_vec, self._fc(obs), action_masks
        )
        return pi_forward(dist, actions)

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Shape of one action: ``(map_size, number of action components)``."""
        n_components = len(self.action_vec)
        return self.map_size, n_components