|
import pytest |
|
import numpy as np |
|
from easydict import EasyDict |
|
|
|
from dizoo.smac.envs import SMACEnv |
|
|
|
# Discrete action indices used by the scripted policy in this module.
# NOTE(review): presumably SMAC's action ids (0 no-op, 1 stop, 2-5 moves,
# 6+ attack/target actions) -- confirm against the environment.
MOVE_EAST = 4
# Not referenced by the code visible in this section; kept for completeness.
MOVE_WEST = 5
|
|
|
|
|
def automation(env, n_agents):
    """Pick one scripted action per agent for both sides.

    "me" agents take action 0 whenever it is available. Otherwise, if no
    action with index >= 6 is available and ``MOVE_EAST`` is, they move east.
    In every remaining case a uniformly random available action is drawn.

    "opponent" agents take the highest-indexed available action as soon as
    any action with index >= 6 is available, otherwise ``MOVE_EAST`` when it
    is available, otherwise a uniformly random available action.

    The global NumPy RNG is only consumed when a random action is really used.

    NOTE(review): indices >= 6 are presumably attack/target actions and 0 the
    no-op (SMAC convention) -- confirm against the environment.

    Args:
        env: Object exposing ``get_avail_agent_actions(agent_id,
            is_opponent=...)`` that returns a per-action availability mask.
        n_agents: Number of agents to choose an action for on each side.

    Returns:
        dict: ``{"me": [...], "opponent": [...]}``, one action per agent.
    """
    actions = {"me": [], "opponent": []}

    for agent_id in range(n_agents):
        avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False)
        avail_actions_ind = np.nonzero(avail_actions)[0]
        if avail_actions[0] != 0:
            action = 0
        elif not np.any(avail_actions[6:]) and avail_actions[MOVE_EAST] != 0:
            # Nothing at index >= 6 is available yet: keep walking east.
            action = MOVE_EAST
        else:
            action = np.random.choice(avail_actions_ind)
        actions["me"].append(action)
        print('ava', avail_actions, action)

    for agent_id in range(n_agents):
        avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True)
        avail_actions_ind = np.nonzero(avail_actions)[0]
        if sum(avail_actions[6:]) > 0:
            # Highest priority: take the highest-indexed available action.
            action = max(avail_actions_ind)
        elif MOVE_EAST in avail_actions_ind:
            action = MOVE_EAST
        else:
            action = np.random.choice(avail_actions_ind)
        actions["opponent"].append(action)

    return actions
|
|
|
|
|
def random_policy(env, n_agents):
    """Choose a uniformly random available action for every agent on both sides.

    Actions are drawn from the global NumPy RNG, first for all "me" agents in
    index order and then for all "opponent" agents in index order.

    Args:
        env: Object exposing ``get_avail_agent_actions(agent_id,
            is_opponent=...)`` that returns a per-action availability mask.
        n_agents: Number of agents to choose an action for on each side.

    Returns:
        dict: ``{"me": [...], "opponent": [...]}``, one action per agent.
    """
    return {
        "me": _random_available_actions(env, n_agents, is_opponent=False),
        "opponent": _random_available_actions(env, n_agents, is_opponent=True),
    }


def _random_available_actions(env, n_agents, is_opponent):
    """Draw one random available action for each agent of a single side."""
    chosen = []
    for agent_id in range(n_agents):
        avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=is_opponent)
        # Indices of the non-zero mask entries are the selectable actions.
        avail_actions_ind = np.nonzero(avail_actions)[0]
        chosen.append(np.random.choice(avail_actions_ind))
    return chosen
|
|
|
|
|
def fix_policy(env, n_agents, me=0, opponent=0):
    """Give every agent a fixed action, falling back when it is unavailable.

    Each "me" agent takes ``me`` and each "opponent" agent takes ``opponent``.
    If that action is not available to an agent, the agent takes its
    lowest-indexed available action instead.

    Args:
        env: Object exposing ``get_avail_agent_actions(agent_id,
            is_opponent=...)`` that returns a per-action availability mask.
        n_agents: Number of agents to choose an action for on each side.
        me: Preferred action for the "me" side.
        opponent: Preferred action for the "opponent" side.

    Returns:
        dict: ``{"me": [...], "opponent": [...]}``, one action per agent.
    """
    return {
        "me": _preferred_or_first_available(env, n_agents, me, is_opponent=False),
        "opponent": _preferred_or_first_available(env, n_agents, opponent, is_opponent=True),
    }


def _preferred_or_first_available(env, n_agents, preferred, is_opponent):
    """Return ``preferred`` per agent, or the agent's first available action."""
    chosen = []
    for agent_id in range(n_agents):
        avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=is_opponent)
        avail_actions_ind = np.nonzero(avail_actions)[0]
        if preferred in avail_actions_ind:
            chosen.append(preferred)
        else:
            chosen.append(avail_actions_ind[0])
    return chosen
|
|
|
|
|
def main(policy, map_name="3m", two_player=False):
    """Run 20 episodes of ``policy`` in a ``SMACEnv`` and check every timestep.

    After each step the observation keys, array types and shapes are
    asserted against ``env.info()``. Win/draw/loss counters and per-episode
    returns are printed after every episode. The environment is closed even
    when an assertion or the environment itself raises.

    Args:
        policy: Callable ``policy(env, n_agents)`` returning
            ``{"me": [...], "opponent": [...]}``. Only the "me" list is sent
            to the environment unless ``two_player`` is true.
        map_name: One of "3s5z", "3m" or "infestor_viper".
        two_player: Passed to the environment config; when true the full
            action dict is stepped and reward/done/info are read per side.

    Raises:
        ValueError: If ``map_name`` is not a supported map. Raised before
            the environment is created so nothing is left open.
        AssertionError: If a timestep violates the expected structure.
    """
    agents_per_map = {"3s5z": 8, "3m": 3, "infestor_viper": 2}
    if map_name not in agents_per_map:
        raise ValueError(f"invalid type: {map_name}")
    n_agents = agents_per_map[map_name]

    cfg = EasyDict(
        {
            'two_player': two_player,
            'map_name': map_name,
            'save_replay_episodes': None,
            'obs_alone': True,
        }
    )
    env = SMACEnv(cfg)
    n_episodes = 20
    me_win = 0
    draw = 0
    op_win = 0

    try:
        for e in range(n_episodes):
            print("Now reset the environment for {} episode.".format(e))
            env.reset()
            print('reset over')
            terminated = False
            episode_return_me = 0
            episode_return_op = 0

            env_info = env.info()
            print('begin new episode')
            while not terminated:
                actions = policy(env, n_agents)
                if not two_player:
                    actions = actions["me"]
                t = env.step(actions)
                obs, reward, terminated, infos = t.obs, t.reward, t.done, t.info
                assert set(obs.keys()) == {
                    'agent_state', 'global_state', 'action_mask', 'agent_alone_state', 'agent_alone_padding_state'
                }
                for key in ('agent_state', 'agent_alone_state', 'global_state'):
                    assert isinstance(obs[key], np.ndarray)
                    assert obs[key].shape == env_info.obs_space.shape[key]
                if not two_player:
                    assert isinstance(reward, np.ndarray)
                    assert reward.shape == (1, )
                print('reward', reward)
                if two_player:
                    # NOTE(review): assumes reward/done are keyed by side in
                    # two-player mode -- confirm against SMACEnv. The flat
                    # ndarray/bool checks cannot hold for such values, so
                    # they are limited to the single-player branch.
                    episode_return_me += reward["me"]
                    episode_return_op += reward["opponent"]
                    terminated = terminated["me"]
                else:
                    assert isinstance(terminated, bool)
                    episode_return_me += reward

            # ``infos`` is the info of the terminal step of this episode.
            if two_player:
                me_win += int(infos["me"]["battle_won"])
                op_win += int(infos["opponent"]["battle_won"])
            else:
                me_win += int(infos["battle_won"])
                op_win += int(infos["battle_lost"])
            draw += int(infos["draw"])

            print(
                "Total return in episode {} = {} (me), {} (opponent). Me win {}, Draw {}, Opponent win {}, total {}."
                "".format(e, episode_return_me, episode_return_op, me_win, draw, op_win, e + 1)
            )
    finally:
        # Always release the environment, including on a failed assertion.
        env.close()
|
|
|
|
|
@pytest.mark.env_test
def test_automation():
    """Run ``main`` with the scripted ``automation`` policy on "infestor_viper" in single-player mode."""
    main(automation, map_name="infestor_viper", two_player=False)
|
|
|
|
|
if __name__ == "__main__":
    # Allow running the environment check directly, without the pytest runner.
    test_automation()
|
|