|
import random |
|
import time |
|
from collections import namedtuple |
|
import pytest |
|
import torch |
|
import numpy as np |
|
from easydict import EasyDict |
|
from functools import partial |
|
import gym |
|
|
|
from ding.envs.env.base_env import BaseEnvTimestep |
|
from ding.envs.env_manager.base_env_manager import EnvState |
|
from ding.envs.env_manager import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager |
|
from ding.torch_utils import to_tensor, to_ndarray, to_list |
|
from ding.utils import deep_merge_dicts |
|
|
|
|
|
class FakeEnv(object): |
|
|
|
def __init__(self, cfg): |
|
self._scale = cfg.scale |
|
self._target_time = random.randint(3, 6) * self._scale |
|
self._current_time = 0 |
|
self._name = cfg['name'] |
|
self._id = time.time() |
|
self._stat = None |
|
self._seed = 0 |
|
self._data_count = 0 |
|
self.timeout_flag = False |
|
self._launched = False |
|
self._state = EnvState.INIT |
|
self._dead_once = False |
|
self.observation_space = gym.spaces.Box( |
|
low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3, ), dtype=np.float32 |
|
) |
|
self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32) |
|
self.reward_space = gym.spaces.Box( |
|
low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1, ), dtype=np.float32 |
|
) |
|
|
|
def reset(self, stat=None): |
|
if isinstance(stat, str) and stat == 'error': |
|
self.dead() |
|
if isinstance(stat, str) and stat == 'error_once': |
|
|
|
if self._dead_once: |
|
self._dead_once = False |
|
self.dead() |
|
else: |
|
self._dead_once = True |
|
if isinstance(stat, str) and stat == "wait": |
|
if self.timeout_flag: |
|
time.sleep(5) |
|
if isinstance(stat, str) and stat == "block": |
|
self.block() |
|
|
|
self._launched = True |
|
self._current_time = 0 |
|
self._stat = stat |
|
self._state = EnvState.RUN |
|
return to_ndarray(torch.randn(3)) |
|
|
|
def step(self, action): |
|
assert self._launched |
|
assert not self._state == EnvState.ERROR |
|
self.timeout_flag = True |
|
if isinstance(action, str) and action == 'error': |
|
self.dead() |
|
if isinstance(action, str) and action == 'catched_error': |
|
return BaseEnvTimestep(None, None, True, {'abnormal': True}) |
|
if isinstance(action, str) and action == "wait": |
|
if self.timeout_flag: |
|
time.sleep(3) |
|
if isinstance(action, str) and action == 'block': |
|
self.block() |
|
obs = to_ndarray(torch.randn(3)) |
|
reward = to_ndarray(torch.randint(0, 2, size=[1]).numpy()) |
|
done = self._current_time >= self._target_time |
|
if done: |
|
self._state = EnvState.DONE |
|
simulation_time = random.uniform(0.5, 1) * self._scale |
|
info = {'name': self._name, 'time': simulation_time, 'tgt': self._target_time, 'cur': self._current_time} |
|
time.sleep(simulation_time) |
|
self._current_time += simulation_time |
|
self._data_count += 1 |
|
return BaseEnvTimestep(obs, reward, done, info) |
|
|
|
def dead(self): |
|
self._state = EnvState.ERROR |
|
raise RuntimeError("env error, current time {}".format(self._current_time)) |
|
|
|
def block(self): |
|
self._state = EnvState.ERROR |
|
time.sleep(1000) |
|
|
|
def close(self): |
|
self._launched = False |
|
self._state = EnvState.INIT |
|
|
|
def seed(self, seed): |
|
self._seed = seed |
|
|
|
@property |
|
def name(self): |
|
return self._name |
|
|
|
@property |
|
def time_id(self): |
|
return self._id |
|
|
|
def user_defined(self): |
|
pass |
|
|
|
def __repr__(self): |
|
return self._name |
|
|
|
|
|
class FakeAsyncEnv(FakeEnv): |
|
|
|
def reset(self, stat=None): |
|
super().reset(stat) |
|
time.sleep(random.randint(1, 3) * self._scale) |
|
return to_ndarray(torch.randn(3)) |
|
|
|
|
|
class FakeGymEnv(FakeEnv): |
|
|
|
def __init__(self, cfg): |
|
super().__init__(cfg) |
|
self.metadata = "fake metadata" |
|
self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(4, ), dtype=np.float32) |
|
|
|
def random_action(self) -> np.ndarray: |
|
random_action = self.action_space.sample() |
|
if isinstance(random_action, np.ndarray): |
|
pass |
|
elif isinstance(random_action, int): |
|
random_action = to_ndarray([random_action], dtype=np.int64) |
|
elif isinstance(random_action, dict): |
|
random_action = to_ndarray(random_action) |
|
else: |
|
raise TypeError( |
|
'`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format( |
|
type(random_action), random_action |
|
) |
|
) |
|
return random_action |
|
|
|
|
|
class FakeModel(object): |
|
|
|
def forward(self, obs): |
|
if random.random() > 0.5: |
|
return {k: [] for k in obs} |
|
else: |
|
env_num = len(obs) |
|
exec_env = random.randint(1, env_num + 1) |
|
keys = list(obs.keys())[:exec_env] |
|
return {k: [] for k in keys} |
|
|
|
|
|
@pytest.fixture(scope='class') |
|
def setup_model_type(): |
|
return FakeModel |
|
|
|
|
|
def get_base_manager_cfg(env_num=3): |
|
manager_cfg = { |
|
'env_cfg': [{ |
|
'name': 'name{}'.format(i), |
|
'scale': 1.0, |
|
} for i in range(env_num)], |
|
'episode_num': 2, |
|
'reset_timeout': 10, |
|
'step_timeout': 8, |
|
'max_retry': 5, |
|
} |
|
return EasyDict(manager_cfg) |
|
|
|
|
|
def get_subprecess_manager_cfg(env_num=3): |
|
manager_cfg = { |
|
'env_cfg': [{ |
|
'name': 'name{}'.format(i), |
|
'scale': 1.0, |
|
} for i in range(env_num)], |
|
'episode_num': 2, |
|
|
|
|
|
'connect_timeout': 8, |
|
'step_timeout': 5, |
|
'max_retry': 2, |
|
} |
|
return EasyDict(manager_cfg) |
|
|
|
|
|
def get_gym_vector_manager_cfg(env_num=3): |
|
manager_cfg = { |
|
'env_cfg': [{ |
|
'name': 'name{}'.format(i), |
|
} for i in range(env_num)], |
|
'episode_num': 2, |
|
'connect_timeout': 8, |
|
'step_timeout': 5, |
|
'max_retry': 2, |
|
'share_memory': True |
|
} |
|
return EasyDict(manager_cfg) |
|
|
|
|
|
@pytest.fixture(scope='function') |
|
def setup_base_manager_cfg(): |
|
manager_cfg = get_base_manager_cfg(3) |
|
env_cfg = manager_cfg.pop('env_cfg') |
|
manager_cfg['env_fn'] = [partial(FakeEnv, cfg=c) for c in env_cfg] |
|
return deep_merge_dicts(BaseEnvManager.default_config(), EasyDict(manager_cfg)) |
|
|
|
|
|
@pytest.fixture(scope='function') |
|
def setup_fast_base_manager_cfg(): |
|
manager_cfg = get_base_manager_cfg(3) |
|
env_cfg = manager_cfg.pop('env_cfg') |
|
for e in env_cfg: |
|
e['scale'] = 0.1 |
|
manager_cfg['env_fn'] = [partial(FakeEnv, cfg=c) for c in env_cfg] |
|
return deep_merge_dicts(BaseEnvManager.default_config(), EasyDict(manager_cfg)) |
|
|
|
|
|
@pytest.fixture(scope='function') |
|
def setup_sync_manager_cfg(): |
|
manager_cfg = get_subprecess_manager_cfg(3) |
|
env_cfg = manager_cfg.pop('env_cfg') |
|
|
|
manager_cfg['shared_memory'] = False |
|
manager_cfg['env_fn'] = [partial(FakeEnv, cfg=c) for c in env_cfg] |
|
return deep_merge_dicts(SyncSubprocessEnvManager.default_config(), EasyDict(manager_cfg)) |
|
|
|
|
|
@pytest.fixture(scope='function') |
|
def setup_async_manager_cfg(): |
|
manager_cfg = get_subprecess_manager_cfg(3) |
|
env_cfg = manager_cfg.pop('env_cfg') |
|
manager_cfg['env_fn'] = [partial(FakeAsyncEnv, cfg=c) for c in env_cfg] |
|
manager_cfg['shared_memory'] = False |
|
return deep_merge_dicts(AsyncSubprocessEnvManager.default_config(), EasyDict(manager_cfg)) |
|
|
|
|
|
@pytest.fixture(scope='function') |
|
def setup_gym_vector_manager_cfg(): |
|
manager_cfg = get_subprecess_manager_cfg(3) |
|
env_cfg = manager_cfg.pop('env_cfg') |
|
manager_cfg['env_fn'] = [partial(FakeGymEnv, cfg=c) for c in env_cfg] |
|
manager_cfg['shared_memory'] = False |
|
return EasyDict(manager_cfg) |
|
|