File size: 3,769 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import pytest
import os
import numpy as np
from dizoo.minigrid.envs import MiniGridEnv
from easydict import EasyDict
import copy
# The two configs below are meant to be exercised by TestMiniGridAKTDTnv.
# Plain-dict settings for the 13x13 AKTDT layout; ``cfg`` is an EasyDict built
# from a deep copy, so changes to ``cfg`` never leak back into ``config``.
config = {
    'env_id': 'MiniGrid-AKTDT-13x13-v0',
    'flat_obs': True,
}
cfg = EasyDict(copy.deepcopy(config))
cfg.cfg_type = 'MiniGridEnvDict'

# Same construction for the 7x7 AKTDT layout.
config2 = {
    'env_id': 'MiniGrid-AKTDT-7x7-1-v0',
    'flat_obs': True,
}
cfg2 = EasyDict(copy.deepcopy(config2))
cfg2.cfg_type = 'MiniGridEnvDict'
@pytest.mark.envtest
class TestMiniGridEnv:
    """Smoke test for ``MiniGridEnv`` built from its default config."""

    def test_naive(self):
        """Step the default env with random actions and validate each timestep.

        The env is seeded with 314, replay saving is enabled under ``./video``,
        and ``env._max_step`` random actions are applied. Every returned
        timestep must carry a flat ``np.ndarray`` observation of shape
        ``(2739, )``, a ``bool`` done flag, and a reward of shape ``(1, )``
        that lies within the bounds reported by ``env.info().rew_space``.
        """
        env = MiniGridEnv(MiniGridEnv.default_config())
        # try/finally guarantees the env is closed even when an assertion
        # fails part-way through the rollout.
        try:
            env.seed(314)
            replay_dir = './video'
            # exist_ok avoids the check-then-create race when several env
            # tests run concurrently and share this directory.
            os.makedirs(replay_dir, exist_ok=True)
            env.enable_save_replay(replay_dir)
            assert env._seed == 314
            env.reset()
            act_val = env.info().act_space.value
            min_val, max_val = act_val['min'], act_val['max']
            for _ in range(env._max_step):
                # NOTE(review): np.random.randint excludes the upper bound;
                # confirm act_space 'max' is meant to be exclusive as well.
                random_action = np.random.randint(min_val, max_val, size=(1, ))
                timestep = env.step(random_action)
                print(timestep)
                print(timestep.obs.max())
                assert isinstance(timestep.obs, np.ndarray)
                assert isinstance(timestep.done, bool)
                assert timestep.obs.shape == (2739, )
                assert timestep.reward.shape == (1, )
                assert timestep.reward >= env.info().rew_space.value['min']
                assert timestep.reward <= env.info().rew_space.value['max']
                # Start a fresh episode so the remaining steps stay valid.
                if timestep.done:
                    env.reset()
            print(env.info())
        finally:
            env.close()
@pytest.mark.envtest
class TestMiniGridAKTDTnv:
    """Smoke tests for the AKTDT MiniGrid layouts configured at module level."""

    def _rollout_and_check(self, env_cfg, expected_obs_shape):
        """Run a random-action rollout on ``env_cfg`` and validate each timestep.

        Arguments:
            - env_cfg: config object passed straight to ``MiniGridEnv``.
            - expected_obs_shape: shape tuple every ``timestep.obs`` must have.

        The env is seeded with 314, replay saving is enabled under ``./video``,
        and ``env._max_step`` random actions are applied. Each timestep must
        carry an ``np.ndarray`` observation of the expected shape, a ``bool``
        done flag, and a reward of shape ``(1, )`` within the bounds reported
        by ``env.info().rew_space``. The name has no ``test`` prefix, so pytest
        does not collect this helper as a test.
        """
        env = MiniGridEnv(env_cfg)
        # try/finally guarantees the env is closed even when an assertion
        # fails part-way through the rollout.
        try:
            env.seed(314)
            replay_dir = './video'
            # exist_ok avoids the check-then-create race when several env
            # tests run concurrently and share this directory.
            os.makedirs(replay_dir, exist_ok=True)
            env.enable_save_replay(replay_dir)
            assert env._seed == 314
            env.reset()
            act_val = env.info().act_space.value
            min_val, max_val = act_val['min'], act_val['max']
            for _ in range(env._max_step):
                # NOTE(review): np.random.randint excludes the upper bound;
                # confirm act_space 'max' is meant to be exclusive as well.
                random_action = np.random.randint(min_val, max_val, size=(1, ))
                timestep = env.step(random_action)
                print(timestep)
                print(timestep.obs.max())
                assert isinstance(timestep.obs, np.ndarray)
                assert isinstance(timestep.done, bool)
                assert timestep.obs.shape == expected_obs_shape
                assert timestep.reward.shape == (1, )
                assert timestep.reward >= env.info().rew_space.value['min']
                assert timestep.reward <= env.info().rew_space.value['max']
                # Start a fresh episode so the remaining steps stay valid.
                if timestep.done:
                    env.reset()
            print(env.info())
        finally:
            env.close()

    def test_adtkt_13(self):
        """Check the 13x13 AKTDT env (``cfg``) yields obs of shape ``(2667, )``."""
        # Previously this test was built from ``cfg2`` (the 7x7 layout), which
        # left ``cfg`` untested and contradicted the 7x7 test's obs shape.
        self._rollout_and_check(cfg, (2667, ))

    def test_adtkt_7(self):
        """Check the 7x7 AKTDT env (``cfg2``) yields obs of shape ``(2619, )``."""
        self._rollout_and_check(cfg2, (2619, ))
|