File size: 3,769 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import pytest
import os
import numpy as np
from dizoo.minigrid.envs import MiniGridEnv
from easydict import EasyDict
import copy

# Environment configurations for the AKTDT MiniGrid variants; they are meant
# to be exercised by the TestMiniGridAKTDTnv test class.
# Raw settings: 13x13 layout and 7x7 layout, both with flattened observations.
config = {'env_id': 'MiniGrid-AKTDT-13x13-v0', 'flat_obs': True}
config2 = {'env_id': 'MiniGrid-AKTDT-7x7-1-v0', 'flat_obs': True}

# Attribute-style wrappers built from deep copies, so the raw dicts above are
# never shared with (or mutated through) the EasyDict objects.
cfg = EasyDict(copy.deepcopy(config))
cfg.cfg_type = 'MiniGridEnvDict'

cfg2 = EasyDict(copy.deepcopy(config2))
cfg2.cfg_type = 'MiniGridEnvDict'


@pytest.mark.envtest
class TestMiniGridEnv:
    """Smoke test for ``MiniGridEnv`` built from its default configuration."""

    def test_naive(self):
        """Step the default env with random actions and validate every timestep.

        The env is seeded with 314, replay saving is enabled under ``./video``,
        and ``env._max_step`` random actions are applied. For each returned
        timestep the test checks that the observation is a ``np.ndarray`` of
        shape ``(2739, )``, that ``done`` is a ``bool``, and that the reward has
        shape ``(1, )`` and lies within the bounds reported by
        ``env.info().rew_space``. The env is reset whenever an episode ends.
        """
        env = MiniGridEnv(MiniGridEnv.default_config())
        env.seed(314)
        path = './video'
        # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(path, exist_ok=True)
        env.enable_save_replay(path)
        assert env._seed == 314
        # try/finally guarantees the env is closed even if an assertion fails.
        try:
            env.reset()
            act_val = env.info().act_space.value
            min_val, max_val = act_val['min'], act_val['max']
            for _ in range(env._max_step):
                # np.random.randint samples from [min_val, max_val): max_val is excluded.
                random_action = np.random.randint(min_val, max_val, size=(1, ))
                timestep = env.step(random_action)
                print(timestep)
                print(timestep.obs.max())
                assert isinstance(timestep.obs, np.ndarray)
                assert isinstance(timestep.done, bool)
                assert timestep.obs.shape == (2739, )
                assert timestep.reward.shape == (1, )
                assert timestep.reward >= env.info().rew_space.value['min']
                assert timestep.reward <= env.info().rew_space.value['max']
                if timestep.done:
                    env.reset()
            print(env.info())
        finally:
            env.close()


@pytest.mark.envtest
class TestMiniGridAKTDTnv:
    """Smoke tests for the AKTDT MiniGrid variants configured by ``cfg`` and ``cfg2``."""

    @staticmethod
    def _rollout_and_check(env_cfg, expected_obs_shape):
        """Step an env built from ``env_cfg`` with random actions and validate each timestep.

        Arguments:
            - env_cfg: configuration object passed to ``MiniGridEnv``.
            - expected_obs_shape (tuple): shape every ``timestep.obs`` must have.

        The env is seeded with 314, replay saving is enabled under ``./video``,
        and ``env._max_step`` random actions are applied. Each timestep must
        carry an ``np.ndarray`` observation of ``expected_obs_shape``, a ``bool``
        ``done`` flag, and a reward of shape ``(1, )`` within the bounds
        reported by ``env.info().rew_space``. The env is reset whenever an
        episode ends.
        """
        env = MiniGridEnv(env_cfg)
        env.seed(314)
        path = './video'
        # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(path, exist_ok=True)
        env.enable_save_replay(path)
        assert env._seed == 314
        # try/finally guarantees the env is closed even if an assertion fails.
        try:
            env.reset()
            act_val = env.info().act_space.value
            min_val, max_val = act_val['min'], act_val['max']
            for _ in range(env._max_step):
                # np.random.randint samples from [min_val, max_val): max_val is excluded.
                random_action = np.random.randint(min_val, max_val, size=(1, ))
                timestep = env.step(random_action)
                print(timestep)
                print(timestep.obs.max())
                assert isinstance(timestep.obs, np.ndarray)
                assert isinstance(timestep.done, bool)
                assert timestep.obs.shape == expected_obs_shape
                assert timestep.reward.shape == (1, )
                assert timestep.reward >= env.info().rew_space.value['min']
                assert timestep.reward <= env.info().rew_space.value['max']
                if timestep.done:
                    env.reset()
            print(env.info())
        finally:
            env.close()

    def test_adtkt_13(self):
        """Check the 13x13 AKTDT env; ``cfg`` is the module-level 13x13 configuration."""
        self._rollout_and_check(cfg, (2667, ))

    def test_adtkt_7(self):
        """Check the 7x7 AKTDT env; ``cfg2`` is the module-level 7x7 configuration."""
        self._rollout_and_check(cfg2, (2619, ))