File size: 1,026 Bytes
76a55af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8edc5d6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Distribution


class PiForward(NamedTuple):
    """Result bundle of an actor's forward pass.

    ``pi_forward`` fills ``logp_a`` and ``entropy`` only when it is given
    actions; otherwise both are ``None``.
    """

    # Action distribution computed by the actor.
    pi: Distribution
    # Log-probability of the supplied actions under ``pi``; None when no actions were given.
    logp_a: Optional[torch.Tensor]
    # Entropy of ``pi``; may be None (``pi_forward`` skips it when no actions are given).
    entropy: Optional[torch.Tensor]


class Actor(nn.Module, ABC):
    """Abstract policy module mapping observations to an action distribution.

    Concrete subclasses must provide ``forward`` and the ``action_shape``
    property; ``sample_weights`` is an optional hook that does nothing here.
    """

    @property
    @abstractmethod
    def action_shape(self) -> Tuple[int, ...]:
        """Shape tuple describing this actor's actions (exact convention set by subclasses)."""

    @abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Run the policy on ``obs`` and return a ``PiForward``.

        Args:
            obs: Observation tensor (shape/dtype defined by the subclass).
            actions: Optional actions. Presumably they are scored (log-prob)
                when given, as ``pi_forward`` does; confirm per subclass.
            action_masks: Optional mask tensor. Its interpretation is up to
                the subclass (likely marks invalid actions; verify).
        """

    def sample_weights(self, batch_size: int = 1) -> None:
        """Optional per-batch hook; the base implementation is a no-op.

        Subclasses may override it. The name suggests resampling stochastic
        weights for ``batch_size`` samples; confirm against implementations.
        """
        return None


def pi_forward(
    distribution: Distribution, actions: Optional[torch.Tensor] = None
) -> PiForward:
    """Wrap ``distribution`` in a ``PiForward``, optionally scoring ``actions``.

    Args:
        distribution: The action distribution to return as ``pi``.
        actions: If provided, ``distribution.log_prob(actions)`` and
            ``distribution.entropy()`` are computed (in that order).

    Returns:
        ``PiForward(pi, logp_a, entropy)``. Both ``logp_a`` and ``entropy``
        are ``None`` when ``actions`` is ``None``.
    """
    # NOTE: without actions, entropy is skipped as well, not only logp_a.
    if actions is None:
        return PiForward(pi=distribution, logp_a=None, entropy=None)

    log_prob = distribution.log_prob(actions)
    dist_entropy = distribution.entropy()
    return PiForward(pi=distribution, logp_a=log_prob, entropy=dist_entropy)