File size: 3,674 Bytes
0e936e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Tuple, Type

import gym
import torch.nn as nn
from gym.spaces import Box, Discrete, MultiDiscrete

from rl_algo_impls.shared.actor.actor import Actor
from rl_algo_impls.shared.actor.categorical import CategoricalActorHead
from rl_algo_impls.shared.actor.gaussian import GaussianActorHead
from rl_algo_impls.shared.actor.gridnet import GridnetActorHead
from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder
from rl_algo_impls.shared.actor.multi_discrete import MultiDiscreteActorHead
from rl_algo_impls.shared.actor.state_dependent_noise import (
    StateDependentNoiseActorHead,
)
from rl_algo_impls.shared.encoder import EncoderOutDim


def _require_int_in_dim(in_dim: EncoderOutDim, action_space: gym.Space) -> int:
    """Return ``in_dim`` unchanged if it is an ``int``, otherwise raise.

    The Discrete and Box heads built by ``actor_head`` are constructed with an
    integer ``in_dim``. An explicit check (rather than ``assert``) keeps the
    validation active under ``python -O`` and still narrows the type for
    static type checkers.

    Raises:
        ValueError: if ``in_dim`` is not an ``int``.
    """
    if not isinstance(in_dim, int):
        raise ValueError(
            f"in_dim must be an int for {type(action_space).__name__} action "
            f"spaces, got {in_dim!r}"
        )
    return in_dim


def actor_head(
    action_space: gym.Space,
    in_dim: EncoderOutDim,
    hidden_sizes: Tuple[int, ...],
    init_layers_orthogonal: bool,
    activation: Type[nn.Module],
    log_std_init: float = -0.5,
    use_sde: bool = False,
    full_std: bool = True,
    squash_output: bool = False,
    actor_head_style: str = "single",
) -> Actor:
    """Build the actor head that matches ``action_space``.

    Args:
        action_space: Environment action space. ``Discrete``, ``Box`` and
            ``MultiDiscrete`` are supported.
        in_dim: Output dimension of the encoder feeding this head. Must be an
            ``int`` for ``Discrete`` and ``Box`` spaces; it is passed through
            unchecked for ``MultiDiscrete`` spaces.
        hidden_sizes: Hidden layer sizes forwarded to the head. The
            ``"gridnet_decoder"`` style does not receive them.
        init_layers_orthogonal: Forwarded to the head.
        activation: Activation module class forwarded to the head.
        log_std_init: Initial log standard deviation. Only used for ``Box``
            spaces.
        use_sde: For ``Box`` spaces, build a ``StateDependentNoiseActorHead``
            instead of a ``GaussianActorHead``. Only valid for ``Box``.
        full_std: Forwarded to ``StateDependentNoiseActorHead`` only.
        squash_output: Forwarded to ``StateDependentNoiseActorHead`` only;
            requires ``use_sde``.
        actor_head_style: For ``MultiDiscrete`` spaces, one of ``"single"``,
            ``"gridnet"`` or ``"gridnet_decoder"``. Ignored for other spaces.

    Returns:
        The constructed actor head.

    Raises:
        ValueError: on an invalid flag combination (``use_sde`` without a
            ``Box`` space, ``squash_output`` without ``use_sde``), a non-int
            ``in_dim`` for ``Discrete``/``Box`` spaces, an unknown
            ``actor_head_style`` for ``MultiDiscrete`` spaces, or an
            unsupported action space type.
    """
    # Explicit raises instead of ``assert`` so these checks survive ``-O``.
    if use_sde and not isinstance(action_space, Box):
        raise ValueError("use_sde only valid if Box action_space")
    if squash_output and not use_sde:
        raise ValueError("squash_output only valid if use_sde")

    if isinstance(action_space, Discrete):
        return CategoricalActorHead(
            action_space.n,  # type: ignore
            in_dim=_require_int_in_dim(in_dim, action_space),
            hidden_sizes=hidden_sizes,
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
        )
    elif isinstance(action_space, Box):
        flat_in_dim = _require_int_in_dim(in_dim, action_space)
        # Only the first dimension of the Box shape is used as the action size.
        if use_sde:
            return StateDependentNoiseActorHead(
                action_space.shape[0],  # type: ignore
                in_dim=flat_in_dim,
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
                log_std_init=log_std_init,
                full_std=full_std,
                squash_output=squash_output,
            )
        return GaussianActorHead(
            action_space.shape[0],  # type: ignore
            in_dim=flat_in_dim,
            hidden_sizes=hidden_sizes,
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
            log_std_init=log_std_init,
        )
    elif isinstance(action_space, MultiDiscrete):
        if actor_head_style == "single":
            return MultiDiscreteActorHead(
                action_space.nvec,  # type: ignore
                in_dim=in_dim,
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
            )
        elif actor_head_style == "gridnet":
            # The gridnet heads take the first nvec entry separately from the
            # remaining entries.
            return GridnetActorHead(
                action_space.nvec[0],  # type: ignore
                action_space.nvec[1:],  # type: ignore
                in_dim=in_dim,
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
            )
        elif actor_head_style == "gridnet_decoder":
            return GridnetDecoder(
                action_space.nvec[0],  # type: ignore
                action_space.nvec[1:],  # type: ignore
                in_dim=in_dim,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
            )
        else:
            raise ValueError(f"Doesn't support actor_head_style {actor_head_style}")
    else:
        raise ValueError(f"Unsupported action space: {action_space}")