# gomoku/LightZero/lzero/envs/wrappers/action_discretization_env_wrapper.py
# zjowowen's picture
# init space
# 079c32c
# raw
# history blame
# 3.68 kB
from itertools import product
import gym
import numpy as np
from easydict import EasyDict
from ding.envs import BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_WRAPPER_REGISTRY
@ENV_WRAPPER_REGISTRY.register('action_discretization_env_wrapper')
class ActionDiscretizationEnvWrapper(gym.Wrapper):
    """
    Overview:
        Environment wrapper that exposes a manually discretized action space. When
        ``cfg.manually_discretization`` is enabled, each of the ``m`` dimensions of the wrapped
        environment's action is split into ``each_dim_disc_size`` (``n``) evenly spaced levels, and the
        Cartesian product of those levels gives ``K = n ** m`` handcrafted discrete actions.
    Interface:
        ``__init__``, ``reset``, ``step``
    Properties:
        - env (:obj:`gym.Env`): the environment to wrap.
    """

    def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
        """
        Overview:
            Wrap ``env`` and cache the config fields used by this wrapper. The discrete action table is \
            not built here; it is built in ``reset``.
        Arguments:
            - env (:obj:`gym.Env`): the environment to wrap.
            - cfg (:obj:`EasyDict`): the env config. ``is_train``, ``env_name`` and ``continuous`` are read \
                here; ``manually_discretization`` (and, when it is truthy, ``each_dim_disc_size``) are read \
                in ``reset`` / ``step``.
        """
        super().__init__(env)
        assert 'is_train' in cfg, '`is_train` flag must set in the config of env'
        self.is_train = cfg.is_train
        self.cfg = cfg
        self.env_name = cfg.env_name
        self.continuous = cfg.continuous

    def reset(self, **kwargs):
        """
        Overview:
            Reset the wrapped environment and, if manual discretization is enabled, rebuild the \
            discrete-index -> per-dimension-level table and the discrete action space.
        Arguments:
            - kwargs (:obj:`Dict`): keyword arguments forwarded unchanged to ``self.env.reset``.
        Returns:
            - observation (:obj:`Any`): whatever ``self.env.reset`` returned.
        """
        obs = self.env.reset(**kwargs)
        # Read the action space only after the inner reset, and keep the original for reference.
        self._raw_action_space = self.env.action_space
        if self.cfg.manually_discretization:
            self.m = self._raw_action_space.shape[0]  # number of action dimensions
            self.n = self.cfg.each_dim_disc_size  # levels per dimension
            self.K = self.n ** self.m  # total number of discrete actions
            # disc_to_cont[i] is the tuple of per-dimension level indices (each in 0..n-1) for the
            # discrete action i; ``step`` turns those indices into continuous values.
            self.disc_to_cont = list(product(range(self.n), repeat=self.m))
            # The modified discrete action space.
            self._action_space = gym.spaces.Discrete(self.K)
        return obs

    def step(self, action) -> BaseEnvTimestep:
        """
        Overview:
            Step the wrapped environment. If manual discretization is enabled, ``action`` is treated as a \
            discrete index and first mapped to a continuous action: level index ``k`` of each dimension \
            becomes ``-1 + 2 / n * k``.
        Arguments:
            - action (:obj:`Any`): a discrete action index (anything accepted by ``int()``) when manual \
                discretization is enabled; otherwise it is passed to the wrapped environment unchanged.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): the ``(obs, reward, done, info)`` returned by \
                ``self.env.step``, packed into a ``BaseEnvTimestep``.
        """
        if self.cfg.manually_discretization:
            # Requires ``reset`` to have been called, since it creates ``self.n`` and ``self.disc_to_cont``.
            # With k in 0..n-1 the produced values lie in [-1, 1 - 2/n]; +1 itself is never emitted.
            # NOTE(review): the constants -1 and 2 presume the wrapped env expects actions in [-1, 1];
            # the raw action space's own bounds are not consulted -- confirm against the wrapped env.
            action = [-1 + 2 / self.n * k for k in self.disc_to_cont[int(action)]]
            action = to_ndarray(action)

        # The core original env step.
        obs, rew, done, info = self.env.step(action)
        return BaseEnvTimestep(obs, rew, done, info)

    def __repr__(self) -> str:
        """Return a fixed, human-readable name for this wrapper."""
        return "Action Discretization Env."