File size: 4,197 Bytes
b18ddcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e1c086
 
 
 
 
 
b18ddcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import numpy as np
import os
import torch
import torch.nn as nn

from abc import ABC, abstractmethod
from copy import deepcopy
from stable_baselines3.common.vec_env import unwrap_vec_normalize
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from typing import Dict, Optional, Type, TypeVar, Union

from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs, find_wrapper

# Activation name (as used in configs) -> torch module class to instantiate.
ACTIVATION: Dict[str, Type[nn.Module]] = dict(
    tanh=nn.Tanh,
    relu=nn.ReLU,
)

# File names written into / read from a policy save directory
# (used by Policy.save and Policy.load below).
VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
MODEL_FILENAME = "model.pth"
NORMALIZE_OBSERVATION_FILENAME = "norm_obs.npz"
NORMALIZE_REWARD_FILENAME = "norm_reward.npz"

# Self-type for Policy methods that return the concrete subclass (e.g. `to`).
PolicySelf = TypeVar("PolicySelf", bound="Policy")


class Policy(nn.Module, ABC):
    """Abstract base class for policies bound to a vectorized environment.

    Besides being an ``nn.Module``, a policy records the normalization wrappers
    found on its env (``VecNormalize``, ``NormalizeObservation``,
    ``NormalizeReward``) so that their statistics can be saved, loaded and
    copied to other envs together with the model weights.
    """

    @abstractmethod
    def __init__(self, env: VecEnv, **kwargs) -> None:
        """Bind the policy to ``env`` and look up its normalization wrappers.

        Args:
            env: Vectorized environment the policy acts in.
            **kwargs: Ignored here; presumably accepted so subclasses can pass
                extra config through ``super().__init__`` — confirm.
        """
        super().__init__()
        self.env = env
        # Each lookup is assumed to return None when the wrapper is absent —
        # TODO confirm against unwrap_vec_normalize / find_wrapper.
        self.vec_normalize = unwrap_vec_normalize(env)
        self.norm_observation = find_wrapper(env, NormalizeObservation)
        self.norm_reward = find_wrapper(env, NormalizeReward)
        # Set by `to`; None means tensors/checkpoints are not moved explicitly.
        self.device = None

    def to(
        self: PolicySelf,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        non_blocking: bool = False,
    ) -> PolicySelf:
        """Move/cast the module like ``nn.Module.to`` and remember the device.

        ``self.device`` is only updated when a device is actually given, so a
        dtype-only call such as ``policy.to(dtype=torch.float16)`` keeps the
        device set by an earlier ``policy.to(device)``. A dtype passed as the
        first positional argument (``policy.to(torch.float16)``) is treated as
        the dtype, mirroring ``nn.Module.to``.

        Returns:
            ``self``, typed as the concrete subclass.
        """
        if dtype is None and isinstance(device, torch.dtype):
            # Mirror nn.Module.to(dtype) rather than recording a dtype as device.
            device, dtype = None, device
        super().to(device, dtype, non_blocking)
        if device is not None:
            self.device = device
        return self

    @abstractmethod
    def act(
        self,
        obs: VecEnvObs,
        deterministic: bool = True,
        action_masks: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Return actions for the given observations; implemented by subclasses.

        Args:
            obs: Observations as returned by the vectorized env.
            deterministic: Whether to act deterministically (default True).
            action_masks: Optional mask array; presumably marks valid actions
                per env — confirm in subclasses.
        """
        ...

    def save(self, path: str) -> None:
        """Write model weights and any normalization statistics into ``path``.

        The directory is created if needed. ``MODEL_FILENAME`` is always
        written; ``VEC_NORMALIZE_FILENAME``, ``NORMALIZE_OBSERVATION_FILENAME``
        and ``NORMALIZE_REWARD_FILENAME`` only when the corresponding wrapper
        was found on the env.
        """
        os.makedirs(path, exist_ok=True)

        if self.vec_normalize is not None:
            self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
        if self.norm_observation is not None:
            self.norm_observation.save(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward is not None:
            self.norm_reward.save(os.path.join(path, NORMALIZE_REWARD_FILENAME))
        torch.save(
            self.state_dict(),
            os.path.join(path, MODEL_FILENAME),
        )

    def load(self, path: str) -> None:
        """Restore weights and normalization statistics written by ``save``.

        Weights are loaded with ``map_location=self.device`` (``None`` until
        ``to`` has been called with a device).
        """
        # VecNormalize load occurs in env.py
        # NOTE: torch.load is pickle-based; only load checkpoints you trust.
        self.load_state_dict(
            torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
        )
        if self.norm_observation is not None:
            self.norm_observation.load(
                os.path.join(path, NORMALIZE_OBSERVATION_FILENAME)
            )
        if self.norm_reward is not None:
            self.norm_reward.load(os.path.join(path, NORMALIZE_REWARD_FILENAME))

    def reset_noise(self) -> None:
        """No-op hook by default.

        The name suggests subclasses override it to resample exploration
        noise — confirm in subclasses.
        """
        pass

    def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
        """Convert ``obs`` to a tensor, moved to ``self.device`` if one is set.

        Only ``np.ndarray`` observations are supported (asserted).
        ``torch.as_tensor`` may share memory with the input array.
        """
        assert isinstance(obs, np.ndarray)
        o = torch.as_tensor(obs)
        if self.device is not None:
            o = o.to(self.device)
        return o

    def num_trainable_parameters(self) -> int:
        """Count elements in parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def num_parameters(self) -> int:
        """Count elements in all parameters, trainable or not."""
        return sum(p.numel() for p in self.parameters())

    def sync_normalization(self, destination_env) -> None:
        """Copy this policy's normalization statistics into ``destination_env``.

        Walks the wrapper chain of ``destination_env`` from the outermost
        wrapper inward until ``current.unwrapped`` is reached. For every
        normalization wrapper found, it deep-copies the running statistics
        from this policy's matching wrapper. For ``VecNormalize`` that means
        ``ret_rms`` always and ``obs_rms`` when present. For
        ``NormalizeObservation`` and ``NormalizeReward`` it means ``rms``.

        Raises:
            AssertionError: ``destination_env`` has a normalization wrapper
                that this policy's env does not.
            AttributeError: a wrapper in the chain exposes neither ``venv`` nor
                ``env`` (or points to itself), so the walk cannot advance.
        """
        current = destination_env
        while current != current.unwrapped:
            if isinstance(current, VecNormalize):
                assert (
                    self.vec_normalize
                ), "destination has VecNormalize but the policy's env does not"
                current.ret_rms = deepcopy(self.vec_normalize.ret_rms)
                if hasattr(self.vec_normalize, "obs_rms"):
                    current.obs_rms = deepcopy(self.vec_normalize.obs_rms)
            elif isinstance(current, NormalizeObservation):
                assert (
                    self.norm_observation
                ), "destination has NormalizeObservation but the policy's env does not"
                current.rms = deepcopy(self.norm_observation.rms)
            elif isinstance(current, NormalizeReward):
                assert (
                    self.norm_reward
                ), "destination has NormalizeReward but the policy's env does not"
                current.rms = deepcopy(self.norm_reward.rms)
            # Step one layer inward: prefer `venv`, fall back to `env`.
            inner = getattr(current, "venv", getattr(current, "env", None))
            if inner is None or inner is current:
                # Staying on `current` would never advance the loop; fail
                # loudly and report the wrapper that broke the chain.
                raise AttributeError(
                    f"{type(current)} doesn't include env or venv attribute"
                )
            current = inner