File size: 6,902 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import copy
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Any, List, Optional, Tuple

import gym
from easydict import EasyDict

from ding.utils import import_module, ENV_REGISTRY

# Container for the result of one ``BaseEnv.step`` call, with fields ``obs``, ``reward``, ``done`` and ``info``.
BaseEnvTimestep = namedtuple(typename='BaseEnvTimestep', field_names='obs reward done info')


# for solving multiple inheritance metaclass conflict between gym and ABC
class FinalMeta(type(ABC), type(gym.Env)):
    pass


class BaseEnv(gym.Env, ABC, metaclass=FinalMeta):
    """
    Overview:
        Basic environment class, extended from ``gym.Env``.
    Interface:
        ``__init__``, ``reset``, ``close``, ``step``, ``seed``, ``__repr__``, ``random_action``, \
        ``create_collector_env_cfg``, ``create_evaluator_env_cfg``, ``enable_save_replay``
    """

    @abstractmethod
    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Lazy init, only related arguments will be initialized in ``__init__`` method, and the concrete \
            env will be initialized the first time ``reset`` method is called.
        Arguments:
            - cfg (:obj:`dict`): Environment configuration in dict type.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self) -> Any:
        """
        Overview:
            Reset the env to an initial state and return an initial observation.
        Returns:
            - obs (:obj:`Any`): Initial observation after reset.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """
        Overview:
            Close env and all the related resources. It should be called once the env instance is no longer used.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, action: Any) -> 'BaseEnvTimestep':
        """
        Overview:
            Run one timestep of the environment's dynamics/simulation.
        Arguments:
            - action (:obj:`Any`): The ``action`` input to step with.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): The result timestep of env executing one step.
        """
        raise NotImplementedError

    @abstractmethod
    def seed(self, seed: int) -> None:
        """
        Overview:
            Set the seed for this env's random number generator(s).
        Arguments:
            - seed (:obj:`int`): Random seed.
        """
        raise NotImplementedError

    @abstractmethod
    def __repr__(self) -> str:
        """
        Overview:
            Return the information string of this env instance.
        Returns:
            - info (:obj:`str`): Information of this env instance, like type and arguments.
        """
        raise NotImplementedError

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Build one config per collector env from the input config, for use by the env manager \
            (a series of vectorized envs) that collects data.
        Arguments:
            - cfg (:obj:`dict`): Original input env config. It must contain ``collector_env_num``, which \
                decides how many configs are generated.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): ``collector_env_num`` configs, each a copy of ``cfg`` \
                without the ``collector_env_num`` key.
        Raises:
            - KeyError: If ``cfg`` has no ``collector_env_num`` key.

        .. note::
            Elements (env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip \
            and port. ``cfg`` itself is left unmodified, and every returned element is an independent deep copy, \
            so changing one env's config cannot affect the others.
        """
        # Work on a copy: popping from ``cfg`` directly would mutate the caller's config, and a repeated call
        # on the same config would then fail with ``KeyError``.
        collector_cfg = copy.deepcopy(cfg)
        collector_env_num = collector_cfg.pop('collector_env_num')
        return [copy.deepcopy(collector_cfg) for _ in range(collector_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Build one config per evaluator env from the input config, for use by the env manager \
            (a series of vectorized envs) that evaluates performance.
        Arguments:
            - cfg (:obj:`dict`): Original input env config. It must contain ``evaluator_env_num``, which \
                decides how many configs are generated.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): ``evaluator_env_num`` configs, each a copy of ``cfg`` \
                without the ``evaluator_env_num`` key.
        Raises:
            - KeyError: If ``cfg`` has no ``evaluator_env_num`` key.

        .. note::
            ``cfg`` itself is left unmodified, and every returned element is an independent deep copy.
        """
        # Same rationale as ``create_collector_env_cfg``: never mutate the caller's config.
        evaluator_cfg = copy.deepcopy(cfg)
        evaluator_env_num = evaluator_cfg.pop('evaluator_env_num')
        return [copy.deepcopy(evaluator_cfg) for _ in range(evaluator_env_num)]

    # optional method
    def enable_save_replay(self, replay_path: str) -> None:
        """
        Overview:
            Save replay file in the given path. This method needs to be implemented by each env class that \
            supports it; the base implementation raises ``NotImplementedError``.
        Arguments:
            - replay_path (:obj:`str`): The path to save replay file.
        """
        raise NotImplementedError

    # optional method
    def random_action(self) -> Any:
        """
        Overview:
            Return a random action generated from the original action space, usually convenient for tests. \
            The base implementation does nothing and returns ``None``; subclasses override it.
        Returns:
            - random_action (:obj:`Any`): Action generated randomly.
        """
        pass


def get_vec_env_setting(
        cfg: dict,
        collect: bool = True,
        eval_: bool = True
) -> Tuple[type, Optional[List[dict]], Optional[List[dict]]]:
    """
    Overview:
        Get vectorized env setting (env_fn, collector_env_cfg, evaluator_env_cfg).
    Arguments:
        - cfg (:obj:`dict`): Original input env config in user config, such as ``cfg.env``. ``type`` is read \
            by attribute access (e.g. an ``EasyDict``), and ``import_names`` is optional.
        - collect (:obj:`bool`): Whether to build ``collector_env_cfg``. If ``False``, it is returned as ``None``.
        - eval_ (:obj:`bool`): Whether to build ``evaluator_env_cfg``. If ``False``, it is returned as ``None``.
    Returns:
        - env_fn (:obj:`type`): Callable object, call it with proper arguments and then get a new env instance.
        - collector_env_cfg (:obj:`Optional[List[dict]]`): A list containing the configs of the data-collecting \
            envs, or ``None`` when ``collect`` is ``False``.
        - evaluator_env_cfg (:obj:`Optional[List[dict]]`): A list containing the configs of the evaluation envs, \
            or ``None`` when ``eval_`` is ``False``.

    .. note::
        Elements (env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
        The same ``cfg`` is passed first to ``env_fn.create_collector_env_cfg`` and then to \
        ``env_fn.create_evaluator_env_cfg``. Whether it gets modified depends on those (possibly overridden) methods.
    """
    # Reuse the shared lookup (module import + registry lookup) instead of duplicating it.
    env_fn = get_env_cls(cfg)
    collector_env_cfg = env_fn.create_collector_env_cfg(cfg) if collect else None
    evaluator_env_cfg = env_fn.create_evaluator_env_cfg(cfg) if eval_ else None
    return env_fn, collector_env_cfg, evaluator_env_cfg


def get_env_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Import the modules listed in ``cfg.import_names`` (if any), then look up and return the env class \
        that ``ENV_REGISTRY`` holds under ``cfg.type``.
    Arguments:
        - cfg (:obj:`EasyDict`): Original input env config in user config, such as ``cfg.env``. ``type`` is \
            read by attribute access, and ``import_names`` defaults to an empty list when absent.
    Returns:
        - env_cls_type (:obj:`type`): The env class found in ``ENV_REGISTRY`` for ``cfg.type``.
    """
    modules_to_import = cfg.get('import_names', [])
    import_module(modules_to_import)
    registered_name = cfg.type
    return ENV_REGISTRY.get(registered_name)


def create_model_env(cfg: EasyDict) -> Any:
    """
    Overview:
        Create model env, which is used in model-based RL.
    Arguments:
        - cfg (:obj:`EasyDict`): Model env config. ``type`` selects the env class via ``get_env_cls``, and the \
            optional ``import_names`` lists modules to import first. Every remaining key is passed to the env \
            class as a keyword argument.
    Returns:
        - model_env (:obj:`Any`): Instance of the env class resolved from ``cfg.type``.

    .. note::
        ``cfg`` is deep-copied before the lookup keys are removed, so the caller's config is left untouched.
    """
    cfg = copy.deepcopy(cfg)
    model_env_fn = get_env_cls(cfg)
    # ``import_names`` is optional (``get_env_cls`` falls back to ``[]``), so its absence must not raise here.
    cfg.pop('import_names', None)
    cfg.pop('type')
    return model_env_fn(**cfg)