from time import time
from easydict import EasyDict
import pytest
import numpy as np
from dizoo.overcooked.envs import OvercookEnv, OvercookGameEnv


@pytest.mark.envtest
class TestOvercooked:
    """Smoke tests that roll out one random-action episode in each Overcooked env."""

    @pytest.mark.parametrize("action_mask", [True, False])
    def test_overcook(self, action_mask):
        """Roll out ``OvercookEnv`` for ``env._horizon`` steps and check observation shapes.

        When ``action_mask`` is enabled, each observation is a dict whose keys must be
        ``'agent_state'`` or ``'action_mask'``, and each value's shape must match the
        corresponding entry of ``env.observation_space``. Otherwise the observation is
        a single array whose shape must match ``env.observation_space``.
        """
        allowed_keys = ('agent_state', 'action_mask')
        sum_rew = 0.0
        env = OvercookEnv(EasyDict({'concat_obs': True, 'action_mask': action_mask}))
        try:
            env.reset()
            # Step exactly `_horizon` times; the episode must be done afterwards.
            for _ in range(env._horizon):
                action = env.random_action()
                timestep = env.step(action)
                obs = timestep.obs
                if action_mask:
                    for k, v in obs.items():
                        # Name the offending key instead of a bare `assert False`.
                        assert k in allowed_keys, 'unexpected observation key: {}'.format(k)
                        assert v.shape == env.observation_space[k].shape
                else:
                    assert obs.shape == env.observation_space.shape
            assert timestep.done
            sum_rew += timestep.info['eval_episode_return'][0]
            print("sum reward is:", sum_rew)
        finally:
            # Release the env even when an assertion above fails.
            env.close()

    @pytest.mark.parametrize("concat_obs", [True, False])
    def test_overcook_game(self, concat_obs):
        """Roll out ``OvercookGameEnv`` for ``env._horizon`` steps and check observation shapes.

        Every observation's shape must equal ``env.observation_space.shape``. The episode
        must be done after the last step, and the per-agent episode returns are printed.
        """
        env = OvercookGameEnv(EasyDict({'concat_obs': concat_obs}))
        try:
            print('observation space: {}'.format(env.observation_space.shape))
            env.reset()
            for _ in range(env._horizon):
                action = env.random_action()
                timestep = env.step(action)
                obs = timestep.obs
                assert obs.shape == env.observation_space.shape
            assert timestep.done
            # `info` is indexed per agent here (one entry for agent 0 and one for agent 1).
            print("agent 0 sum reward is:", timestep.info[0]['eval_episode_return'])
            print("agent 1 sum reward is:", timestep.info[1]['eval_episode_return'])
        finally:
            # Release the env even when an assertion above fails.
            env.close()