import copy
from datetime import datetime
from typing import Union, Dict

import gymnasium as gym
import numpy as np
from ding.envs import BaseEnvTimestep
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
from easydict import EasyDict

from zoo.classic_control.cartpole.envs.cartpole_lightzero_env import CartPoleEnv


@ENV_REGISTRY.register('pendulum_lightzero')
class PendulumEnv(CartPoleEnv):
    """
    LightZero version of the classic Pendulum environment. This class includes methods for resetting, closing, and
    stepping through the environment, as well as seeding for reproducibility, saving replay videos, and generating
    random actions. It also includes properties for accessing the observation space, action space, and reward space
    of the environment.
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        # (bool) Whether to use the continuous action space.
        continuous=True,
        # (str or None) The path to save the replay video. If None, the replay will not be saved.
        # Only effective when env_manager.type is 'base'.
        replay_path=None,
        # (bool) Whether to scale actions into [-2, 2].
        act_scale=True,
    )

    def __init__(self, cfg: dict) -> None:
        """
        Initialize the environment with a configuration dictionary. Sets up spaces for observations, actions,
        and rewards.
        """
        self._cfg = cfg
        self._act_scale = cfg.act_scale
        try:
            self._env = gym.make('Pendulum-v1', render_mode="rgb_array")
        except Exception:
            self._env = gym.make('Pendulum-v0', render_mode="rgb_array")
        self._init_flag = False
        self._replay_path = cfg.replay_path
        self._continuous = cfg.get("continuous", True)
        self._observation_space = gym.spaces.Box(
            low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3, ), dtype=np.float32
        )
        if self._continuous:
            self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32)
        else:
            self.discrete_action_num = 11
            self._action_space = gym.spaces.Discrete(self.discrete_action_num)
        self._action_space.seed(0)  # default seed
        self._reward_space = gym.spaces.Box(
            low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1, ), dtype=np.float32
        )

    def reset(self) -> Dict[str, np.ndarray]:
        """
        Reset the environment. If it has not been initialized yet, this method also handles that. It also handles
        seeding if necessary. Returns the first observation.
        """
        if not self._init_flag:
            try:
                self._env = gym.make('Pendulum-v1', render_mode="rgb_array")
            except Exception:
                self._env = gym.make('Pendulum-v0', render_mode="rgb_array")
            if self._replay_path is not None:
                timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
                video_name = f'{self._env.spec.id}-video-{timestamp}'
                self._env = gym.wrappers.RecordVideo(
                    self._env,
                    video_folder=self._replay_path,
                    episode_trigger=lambda episode_id: True,
                    name_prefix=video_name
                )
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._seed = self._seed + np_seed
            self._action_space.seed(self._seed)
            obs, _ = self._env.reset(seed=self._seed)
        elif hasattr(self, '_seed'):
            self._action_space.seed(self._seed)
            obs, _ = self._env.reset(seed=self._seed)
        else:
            obs, _ = self._env.reset()
        obs = to_ndarray(obs).astype(np.float32)
        self._eval_episode_return = 0.
        if not self._continuous:
            action_mask = np.ones(self.discrete_action_num, 'int8')
        else:
            action_mask = None
        obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
        return obs

    def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
        """
        Overview:
            Step the environment forward with the provided action. This method returns the next state of the
            environment (observation, reward, done flag, and info dictionary) encapsulated in a BaseEnvTimestep
            object.
        Arguments:
            - action (:obj:`Union[int, np.ndarray]`): The action to be performed in the environment.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): An object containing the new observation, reward, done flag,
              and info dictionary.

        .. note::
            - If the environment requires discrete actions, they are converted to float actions in the range [-1, 1].
            - If action scaling is enabled, continuous actions are scaled into the range [-2, 2].
            - For each step, the cumulative reward (``_eval_episode_return``) is updated.
            - If the episode ends (done is True), the total reward for the episode is stored in the info dictionary
              under the key 'eval_episode_return'.
            - If the environment requires discrete actions, an action mask is created; otherwise, it is None.
            - Observations are returned as a dictionary containing 'observation', 'action_mask', and 'to_play'.
        """
        if isinstance(action, int):
            action = np.array(action)
        # If the environment uses discrete actions, convert them to float actions in [-1, 1].
        if not self._continuous:
            action = (action / (self.discrete_action_num - 1)) * 2 - 1
        # Scale the continuous action into [-2, 2].
        if self._act_scale:
            action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
        obs, rew, terminated, truncated, info = self._env.step(action)
        done = terminated or truncated
        self._eval_episode_return += rew
        obs = to_ndarray(obs).astype(np.float32)
        # Wrap the reward so it is transferred as an array with shape (1, ).
        rew = to_ndarray([rew]).astype(np.float32)
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        if not self._continuous:
            action_mask = np.ones(self.discrete_action_num, 'int8')
        else:
            action_mask = None
        obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
        return BaseEnvTimestep(obs, rew, done, info)

    def random_action(self) -> np.ndarray:
        """
        Generate a random action using the action space's sample method. Returns a numpy array containing the action.
        """
        if self._continuous:
            random_action = self.action_space.sample().astype(np.float32)
        else:
            random_action = self.action_space.sample()
            random_action = to_ndarray([random_action], dtype=np.int64)
        return random_action

    def __repr__(self) -> str:
        """
        String representation of the environment.
        """
        return "LightZero Pendulum Env({})".format(self._cfg.get('env_id', 'Pendulum-v1'))
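

# A minimal usage sketch, not part of the original module: it assumes only the default config defined above and
# shows one reset/step cycle driven by random actions. The 200-step cap mirrors the standard Pendulum-v1 time limit;
# adjust it to taste.
if __name__ == "__main__":
    env = PendulumEnv(PendulumEnv.default_config())
    obs = env.reset()
    # The observation dict carries the raw (cos(theta), sin(theta), angular velocity) vector plus MCTS metadata.
    assert obs['observation'].shape == (3, )
    for _ in range(200):
        timestep = env.step(env.random_action())
        if timestep.done:
            print('eval_episode_return:', timestep.info['eval_episode_return'])
            break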