sgoodfriend's picture
A2C playing HalfCheetahBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
8edc5d6
raw
history blame
3.81 kB
from typing import Dict, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray
from torch.distributions import Distribution, constraints
from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
from rl_algo_impls.shared.encoder import EncoderOutDim
from rl_algo_impls.shared.module.utils import mlp
class GridnetDistribution(Distribution):
    """Joint distribution over a grid of per-cell, multi-component discrete actions.

    ``action_vec`` lists how many choices each action component has. Logits and
    masks are flattened to one row per cell and split column-wise into one
    ``MaskedCategorical`` per component. Each categorical's batch dimension
    therefore runs over every flattened cell.

    ``log_prob`` and ``entropy`` reshape the per-cell, per-component values to
    ``(-1, map_size, len(action_vec))`` and sum over the last two axes. The
    result has one value per leading batch entry.
    """

    def __init__(
        self,
        map_size: int,
        action_vec: NDArray[np.int64],
        logits: torch.Tensor,
        masks: torch.Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        self.map_size = map_size
        self.action_vec = action_vec

        # Column widths of each action component inside a flattened cell row.
        component_sizes = action_vec.tolist()

        # Masks: one row per cell; torch.split requires the last dim to equal
        # sum(action_vec).
        cell_masks = masks.view(-1, masks.shape[-1])
        mask_chunks = torch.split(cell_masks, component_sizes, dim=1)

        # Logits: regroup into rows of sum(action_vec) entries, one row per cell.
        cell_logits = logits.reshape(-1, action_vec.sum())
        logit_chunks = torch.split(cell_logits, component_sizes, dim=1)

        self.categoricals = [
            MaskedCategorical(logits=chunk_logits, validate_args=validate_args, mask=chunk_mask)
            for chunk_logits, chunk_mask in zip(logit_chunks, mask_chunks)
        ]

        if logits.ndimension() > 1:
            batch_shape = logits.size()[:-1]
        else:
            batch_shape = torch.Size()
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def _as_grid(self, stacked: torch.Tensor) -> torch.Tensor:
        """Reshape a per-component stacked tensor to (-1, map_size, n_components)."""
        return stacked.view(-1, self.map_size, len(self.action_vec))

    def log_prob(self, action: torch.Tensor) -> torch.Tensor:
        """Return the summed log-probability of ``action`` per leading batch entry.

        ``action``'s last dim is presumably one index per action component. It
        is flattened to rows and transposed, so row ``i`` feeds the ``i``-th
        categorical.
        """
        per_component = action.view(-1, action.shape[-1]).T
        component_log_probs = [
            dist.log_prob(choice)
            for choice, dist in zip(per_component, self.categoricals)
        ]
        return self._as_grid(torch.stack(component_log_probs, dim=-1)).sum(dim=(1, 2))

    def entropy(self) -> torch.Tensor:
        """Return the entropy summed over cells and components per batch entry."""
        component_entropies = [dist.entropy() for dist in self.categoricals]
        return self._as_grid(torch.stack(component_entropies, dim=-1)).sum(dim=(1, 2))

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """Sample each component independently, shaped (-1, map_size, n_components)."""
        draws = [dist.sample(sample_shape) for dist in self.categoricals]
        return self._as_grid(torch.stack(draws, dim=-1))

    @property
    def mode(self) -> torch.Tensor:
        """Most likely choice per component, shaped (-1, map_size, n_components)."""
        modes = [dist.mode for dist in self.categoricals]
        return self._as_grid(torch.stack(modes, dim=-1))

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        # No constraints at this level; each child MaskedCategorical carries its own.
        return {}
class GridnetActorHead(Actor):
    """Actor head that turns encoded observations into a ``GridnetDistribution``.

    An MLP maps each encoded observation to ``map_size * sum(action_vec)``
    logits. ``GridnetDistribution`` regroups them into one set of per-component
    logits for every map cell. Action masks are mandatory.
    """

    def __init__(
        self,
        map_size: int,
        action_vec: NDArray[np.int64],
        in_dim: EncoderOutDim,
        hidden_sizes: Tuple[int, ...] = (32,),
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        self.map_size = map_size
        self.action_vec = action_vec

        # This head only supports a flat (integer) encoder output size.
        assert isinstance(in_dim, int)

        logits_per_obs = map_size * action_vec.sum()
        widths = (in_dim,) + hidden_sizes + (logits_per_obs,)
        # The attribute name `_fc` determines state-dict keys; keep it stable.
        # The small final-layer gain presumably starts the policy near-uniform.
        self._fc = mlp(
            widths,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )

    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Build the action distribution for ``obs`` and pass it to ``pi_forward``.

        ``actions`` is forwarded unchanged to ``pi_forward``.
        ``action_masks`` must be provided; the unmasked case is not implemented.
        """
        assert (
            action_masks is not None
        ), f"No mask case unhandled in {self.__class__.__name__}"
        distribution = GridnetDistribution(
            self.map_size, self.action_vec, self._fc(obs), action_masks
        )
        return pi_forward(distribution, actions)

    @property
    def action_shape(self) -> Tuple[int, ...]:
        """Action shape: (map_size, number of action components)."""
        return self.map_size, len(self.action_vec)