# Provenance (recovered from the file-viewer header this file was copied with):
#   author:   sgoodfriend
#   summary:  A2C playing HalfCheetahBulletEnv-v0 from
#             https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
#   revision: 8edc5d6
#   size:     1.03 kB
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Distribution
class PiForward(NamedTuple):
    """Result bundle returned by a policy forward pass.

    Being a ``NamedTuple``, it can be unpacked positionally as
    ``(pi, logp_a, entropy)`` or accessed by field name.
    """

    # The action distribution produced for the given observations.
    pi: Distribution
    # Log-probability of the supplied actions under ``pi``; ``None`` when
    # it was not computed (``pi_forward`` leaves it ``None`` if no actions
    # were passed).
    logp_a: Optional[torch.Tensor]
    # Entropy of ``pi``; ``None`` when it was not computed (``pi_forward``
    # leaves it ``None`` if no actions were passed).
    entropy: Optional[torch.Tensor]
class Actor(nn.Module, ABC):
    """Abstract base class for policy ("actor") modules.

    Concrete subclasses must implement :meth:`forward` and the
    :attr:`action_shape` property. :meth:`sample_weights` is an optional
    hook that does nothing unless a subclass overrides it.
    """

    @abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Run the policy and return a ``PiForward`` result.

        Args:
            obs: Observation tensor fed to the policy.
            actions: Optional tensor of actions. Presumably the actions to
                score (log-probability / entropy) -- confirm against the
                concrete implementations.
            action_masks: Optional mask tensor. Presumably restricts which
                actions are allowed; exact semantics are defined by the
                subclass.

        Returns:
            A ``PiForward`` tuple ``(pi, logp_a, entropy)``.
        """

    def sample_weights(self, batch_size: int = 1) -> None:
        """Optional hook; the base implementation is a no-op.

        Args:
            batch_size: Unused here; available to overriding subclasses.
        """
        return None

    @property
    @abstractmethod
    def action_shape(self) -> Tuple[int, ...]:
        """Shape of the actions handled by this actor, as a tuple of ints."""
def pi_forward(
    distribution: Distribution, actions: Optional[torch.Tensor] = None
) -> PiForward:
    """Wrap a distribution (and optional action statistics) in a ``PiForward``.

    Args:
        distribution: The policy's action distribution.
        actions: Optional actions to evaluate under ``distribution``.

    Returns:
        ``PiForward(distribution, logp_a, entropy)``. When ``actions`` is
        ``None`` both ``logp_a`` and ``entropy`` are ``None``; otherwise they
        are ``distribution.log_prob(actions)`` and ``distribution.entropy()``.
    """
    # Without actions there is nothing to score, so skip both computations.
    if actions is None:
        return PiForward(distribution, None, None)

    # Arguments are evaluated left to right: log_prob first, then entropy.
    return PiForward(
        distribution,
        distribution.log_prob(actions),
        distribution.entropy(),
    )