File size: 1,254 Bytes
5d5d65a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gym
import torch as th
import torch.nn as nn

from gym.spaces import Discrete
from typing import Optional, Sequence, Type

from shared.module.feature_extractor import FeatureExtractor
from shared.module.module import mlp


class QNetwork(nn.Module):
    """Network that maps observations to one output per discrete action.

    Observations are first passed through a ``FeatureExtractor``; the result
    is then fed to a head built by ``mlp``. The head is requested with layer
    sizes running from the extractor's ``out_dim``, through ``hidden_sizes``,
    to ``action_space.n``.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
    ) -> None:
        """Build the feature extractor and the fully connected head.

        Args:
            observation_space: Space forwarded to ``FeatureExtractor``.
            action_space: Must be a ``gym.spaces.Discrete``; its ``n`` is used
                as the final layer size.
            hidden_sizes: Sizes inserted between the extractor output and the
                final layer. Empty by default, i.e. no intermediate sizes.
            activation: Module class forwarded to both ``FeatureExtractor``
                and ``mlp``.
            cnn_feature_dim: Forwarded unchanged to ``FeatureExtractor``.
            cnn_style: Forwarded unchanged to ``FeatureExtractor``.
            cnn_layers_init_orthogonal: Forwarded unchanged to
                ``FeatureExtractor``.

        Raises:
            AssertionError: If ``action_space`` is not ``Discrete``.
        """
        super().__init__()
        assert isinstance(action_space, Discrete), (
            "QNetwork requires a Discrete action space, "
            f"got {type(action_space).__name__}"
        )
        self._feature_extractor = FeatureExtractor(
            observation_space,
            activation,
            cnn_feature_dim=cnn_feature_dim,
            cnn_style=cnn_style,
            cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
        )
        # Input width comes from the extractor, output width from the number
        # of discrete actions; hidden_sizes fills in whatever lies between.
        layer_sizes = (
            (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
        )
        # NOTE(review): presumably mlp builds one linear layer per adjacent
        # pair in layer_sizes -- confirm against shared.module.module.mlp.
        self._fc = mlp(layer_sizes, activation)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Run ``obs`` through the feature extractor, then the head.

        Args:
            obs: Observation tensor accepted by the feature extractor.

        Returns:
            The head's output for the extracted features.
        """
        x = self._feature_extractor(obs)
        return self._fc(x)