# gomoku / DI-engine / dizoo / bsuite / envs / test_bsuite_env.py
# zjowowen: "init space" (079c32c)
from time import time
import pytest
import numpy as np
from easydict import EasyDict
from dizoo.bsuite.envs import BSuiteEnv
@pytest.mark.envtest
class TestBSuiteEnv:
    """Smoke tests that play one random-action episode in a ``BSuiteEnv``."""

    @staticmethod
    def _run_random_episode(env_id, obs_shape):
        """Play one episode with random actions and check every timestep.

        Args:
            env_id: Value placed in the config as ``env_id`` when building
                the ``BSuiteEnv``.
            obs_shape: Shape expected for the observation returned by
                ``reset()`` and by every ``step()``.

        Raises:
            AssertionError: If an observation or reward has an unexpected
                shape, or if the final timestep's ``info`` lacks
                ``'eval_episode_return'``.
        """
        env = BSuiteEnv(EasyDict({'env_id': env_id}))
        # Release the environment even when an assertion or a step fails,
        # so a failing test does not leak it into the rest of the session.
        try:
            env.seed(0)
            obs = env.reset()
            assert obs.shape == obs_shape
            while True:
                timestep = env.step(env.random_action())
                assert timestep.obs.shape == obs_shape
                assert timestep.reward.shape == (1, )
                if timestep.done:
                    assert 'eval_episode_return' in timestep.info, timestep.info
                    break
        finally:
            env.close()

    def test_memory_len(self):
        """Random episode on ``memory_len/0``; observations have shape (3,)."""
        self._run_random_episode('memory_len/0', (3, ))

    def test_cartpole_swingup(self):
        """Random episode on ``cartpole_swingup/0``; observations have shape (8,)."""
        self._run_random_episode('cartpole_swingup/0', (8, ))