|
from time import time |
|
import pytest |
|
import numpy as np |
|
from easydict import EasyDict |
|
from dizoo.bsuite.envs import BSuiteEnv |
|
|
|
|
|
@pytest.mark.envtest
class TestBSuiteEnv:
    """Smoke tests that play one random-action episode in a ``BSuiteEnv``.

    Each test builds the environment from an ``env_id``, seeds it with 0,
    and checks the shapes of the observations and rewards it returns until
    the episode reports ``done``.
    """

    @staticmethod
    def _run_random_episode(env_id, obs_shape):
        """Play one seeded random-action episode and assert on its outputs.

        Args:
            env_id: Value stored under the ``'env_id'`` key of the config
                passed to ``BSuiteEnv``.
            obs_shape: Expected ``.shape`` of the observation returned by
                ``reset()`` and by every ``step()``.

        Raises:
            AssertionError: If an observation or reward has an unexpected
                shape, or the final timestep's ``info`` lacks
                ``'eval_episode_return'``.
        """
        env = BSuiteEnv(EasyDict({'env_id': env_id}))
        # try/finally so that close() still runs when an assertion fails;
        # otherwise a failing check would skip the cleanup call entirely.
        try:
            env.seed(0)
            obs = env.reset()
            assert obs.shape == obs_shape
            # Loop until the environment itself signals the end of the episode.
            while True:
                random_action = env.random_action()
                timestep = env.step(random_action)
                assert timestep.obs.shape == obs_shape
                assert timestep.reward.shape == (1, )
                if timestep.done:
                    assert 'eval_episode_return' in timestep.info, timestep.info
                    break
        finally:
            env.close()

    def test_memory_len(self):
        """Run one episode of ``memory_len/0``, expecting observations of shape (3,)."""
        self._run_random_episode('memory_len/0', (3, ))

    def test_cartpole_swingup(self):
        """Run one episode of ``cartpole_swingup/0``, expecting observations of shape (8,)."""
        self._run_random_episode('cartpole_swingup/0', (8, ))
|
|