# File size: 7,078 Bytes
# 079c32c
from tabnanny import check
from typing import Any, Callable, List, Tuple
import numpy as np
from collections.abc import Sequence
from easydict import EasyDict
from ding.envs.env import BaseEnv, BaseEnvTimestep
from ding.envs.env.tests import DemoEnv
# from dizoo.atari.envs import AtariEnv
def check_space_dtype(env: BaseEnv) -> None:
    """Assert that the obs/act/rew spaces of ``env`` use DI-engine's canonical dtypes.

    ``env`` is reset first. Then, for each of ``observation_space``,
    ``action_space`` and ``reward_space``:

    * a dtype whose repr mentions ``float`` must be exactly ``np.float32``;
    * a dtype whose repr mentions ``int`` must be exactly ``np.int64``.

    Spaces whose dtype repr matches neither substring are not checked.

    Raises:
        AssertionError: If a float/int space uses any other precision.
    """
    print("== 0. Test obs/act/rew space's dtype")
    env.reset()
    labelled_spaces = (
        ('obs', env.observation_space),
        ('act', env.action_space),
        ('rew', env.reward_space),
    )
    for label, sp in labelled_spaces:
        dtype_text = repr(sp.dtype)
        # NOTE: this is a plain substring test, so unsigned dtypes such as
        # ``uint8`` also count as "int" and are required to be np.int64.
        if 'float' in dtype_text:
            assert sp.dtype == np.float32, "If float, then must be np.float32, but get {} for {} space".format(
                sp.dtype, label
            )
        if 'int' in dtype_text:
            assert sp.dtype == np.int64, "If int, then must be np.int64, but get {} for {} space".format(
                sp.dtype, label
            )
# Util function
def check_array_space(ndarray, space, name) -> None:
    """Recursively assert that ``ndarray`` conforms to ``space``.

    Leaves must be ``np.ndarray`` objects whose dtype and shape equal
    ``space.dtype`` / ``space.shape`` and whose values all lie within
    ``[space.low, space.high]``. Non-string sequences are matched element-wise
    against ``space[i]``; dicts are matched key-wise against ``space[k]``.

    Args:
        ndarray: An ``np.ndarray``, or an arbitrarily nested sequence/dict of them.
        space: The matching space. Leaves need ``dtype``, ``shape``, ``low`` and
            ``high``; containers must be indexable by the same index/key.
        name: Label used in error messages (e.g. ``'obs'``).

    Raises:
        AssertionError: If a leaf's dtype, shape or value range does not match.
        TypeError: If an element is neither an ``np.ndarray`` nor a
            non-string sequence/dict.
    """
    if isinstance(ndarray, np.ndarray):
        assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
            name, ndarray.dtype, space.dtype
        )
        assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
            name, ndarray.shape, space.shape
        )
        assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
        ), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
    # Strings are Sequences whose items are again strings; without this exclusion
    # a str leaf would recurse forever instead of being reported as a bad type.
    elif isinstance(ndarray, Sequence) and not isinstance(ndarray, (str, bytes, bytearray)):
        for i, item in enumerate(ndarray):
            try:
                check_array_space(item, space[i], name)
            except AssertionError:
                print("The following error happens at {}-th index".format(i))
                raise
    elif isinstance(ndarray, dict):
        for k, value in ndarray.items():
            try:
                check_array_space(value, space[k], name)
            except AssertionError:
                print("The following error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray))
        )
def check_reset(env: BaseEnv) -> None:
    """Reset ``env`` once and validate the returned observation.

    The observation produced by ``env.reset()`` is checked against
    ``env.observation_space`` with ``check_array_space`` under the label ``'obs'``.
    """
    print('== 1. Test reset method')
    # Arguments are evaluated left to right: reset first, then read the space.
    check_array_space(env.reset(), env.observation_space, 'obs')
def check_step(env: BaseEnv, episode_num: int = 3) -> None:
    """Step ``env`` with random actions and validate every returned obs/reward.

    After each step, ``obs`` is checked against ``env.observation_space`` and
    ``rew`` against ``env.reward_space`` via ``check_array_space``. When an
    episode ends, ``info`` must contain ``'eval_episode_return'`` and the env is
    reset. The check finishes once ``episode_num`` episodes have ended.

    Args:
        env: The env under test.
        episode_num: Number of episodes to run (default 3).

    Raises:
        AssertionError: If an obs/reward violates its space, or a terminal
            ``info`` lacks ``'eval_episode_return'``.
    """
    done_times = 0
    print('== 2. Test step method')
    env.reset()
    while done_times < episode_num:
        # Sample a fresh action every step, matching the usage shown in
        # ``demonstrate_correct_procedure``. Repeating a single action can
        # keep some envs from ever terminating, which would hang this check.
        if hasattr(env, "random_action"):
            action = env.random_action()
        else:
            action = env.action_space.sample()
        obs, rew, done, info = env.step(action)
        for ndarray, space, name in zip([obs, rew], [env.observation_space, env.reward_space], ['obs', 'rew']):
            check_array_space(ndarray, space, name)
        if done:
            assert 'eval_episode_return' in info, "info dict should have 'eval_episode_return' key."
            done_times += 1
            env.reset()
# Util function
def check_different_memory(array1, array2, step_times) -> None:
    """Assert that two consecutive observations do not share memory.

    Both arguments must have exactly the same type. ``np.ndarray`` leaves must
    be distinct objects that do not share any memory. Sequences must have equal
    length and dicts equal keys; both are compared recursively.

    Args:
        array1: Observation from the previous frame.
        array2: Observation from the current frame.
        step_times: Step counter, used only in error messages.

    Raises:
        AssertionError: If types, lengths or keys differ, or two leaves share memory.
        TypeError: If a leaf is neither an ``np.ndarray`` nor a non-string sequence/dict.
    """
    assert type(array1) is type(
        array2
    ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format(
        step_times, type(array1), type(array2)
    )
    if isinstance(array1, np.ndarray):
        assert array1 is not array2, "In step times {}, obs_last_frame and obs_this_frame are the same np.ndarray".format(
            step_times
        )
        # An identity check alone misses views/slices of the previous buffer. Those are
        # distinct objects, yet mutating one still changes the other, so they are not deep copies.
        assert not np.shares_memory(
            array1, array2
        ), "In step times {}, obs_last_frame and obs_this_frame share the same memory".format(step_times)
    # Exclude strings: their items are again strings, which would recurse forever.
    elif isinstance(array1, Sequence) and not isinstance(array1, (str, bytes, bytearray)):
        assert len(array1) == len(
            array2
        ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) have different sequence lengths".format(
            step_times, len(array1), len(array2)
        )
        for i, (item1, item2) in enumerate(zip(array1, array2)):
            try:
                check_different_memory(item1, item2, step_times)
            except AssertionError:
                print("The following error happens at {}-th index".format(i))
                raise
    elif isinstance(array1, dict):
        # Implicit concatenation rather than a backslash continuation inside the
        # literal, which would have embedded source indentation into the message.
        assert array1.keys() == array2.keys(), (
            "In step times {}, obs_last_frame({}) and obs_this_frame({}) have "
            "different dict keys".format(step_times, array1.keys(), array2.keys())
        )
        for k in array1:
            try:
                check_different_memory(array1[k], array2[k], step_times)
            except AssertionError:
                print("The following error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or list/dict of np.ndarray, but found {} and {}".format(
                type(array1), type(array2)
            )
        )
def check_obs_deepcopy(env: BaseEnv) -> None:
    """Run one episode and assert consecutive observations never share memory.

    The observation from ``env.reset()`` and each observation from ``env.step``
    are compared pairwise (previous vs. current) with ``check_different_memory``.
    The check stops at the first ``done``.

    Raises:
        AssertionError: If two consecutive observations share memory or differ in structure.
    """
    step_times = 0
    print('== 3. Test observation deepcopy')
    obs_1 = env.reset()
    while True:
        step_times += 1
        # Sample a fresh action every step. Repeating one fixed action can keep
        # some envs from reaching ``done``, which would hang this loop.
        if hasattr(env, "random_action"):
            action = env.random_action()
        else:
            action = env.action_space.sample()
        obs_2, _, done, _ = env.step(action)
        check_different_memory(obs_1, obs_2, step_times)
        obs_1 = obs_2
        if done:
            break
def check_all(env: BaseEnv) -> None:
    """Run every env-implementation check on ``env``, in order.

    Order: space dtypes, reset, step, observation deepcopy. The first failing
    check raises and stops the rest.
    """
    for run_check in (check_space_dtype, check_reset, check_step, check_obs_deepcopy):
        run_check(env)
def demonstrate_correct_procedure(env_fn: Callable) -> None:
    """Walk through the intended life cycle of an env.

    The env is created with ``env_fn({})`` and must not have ``_env`` yet
    (lazy init). It is seeded with 4 before the first ``reset`` and then stepped
    with random actions for three episodes, resetting after each one. The seed
    must stay 4 across resets. Finally the env is closed.

    Args:
        env_fn: Env class or factory, called with an empty config dict.
    """
    print('== 4. Demonstrate the correct procedures')
    done_times = 0
    # Init the env.
    env = env_fn({})
    # Lazy init. The real env is not initialized until `reset` method is called
    assert not hasattr(env, "_env")
    # Must set seed before `reset` method is called.
    env.seed(4)
    assert env._seed == 4
    # Reset the env. The real env is initialized here.
    obs = env.reset()
    try:
        while done_times < 3:
            # Using the policy to get the action from obs. But here we use `random_action` instead.
            action = env.random_action()
            obs, rew, done, info = env.step(action)
            if done:
                assert 'eval_episode_return' in info
                done_times += 1
                obs = env.reset()
                # Seed will not change unless `seed` method is called again.
                assert env._seed == 4
    finally:
        # Always release the real env (created by `reset`), even if an assertion fails.
        env.close()
if __name__ == "__main__":
    # Methods `check_*` are for users to check whether their own env obeys DI-engine's rules.
    # Replace `AtariEnv` with your own env, for example:
    #
    #     from dizoo.atari.envs import AtariEnv
    #     atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False))
    #     check_reset(atari_env)
    #     check_step(atari_env)
    #     check_obs_deepcopy(atari_env)
    #
    # Method `demonstrate_correct_procedure` demonstrates the correct procedure for
    # using an env to generate trajectories.
    # You can check whether your env's design is similar to `DemoEnv`.
    demonstrate_correct_procedure(DemoEnv)
|