import copy
import os
from collections import namedtuple
from typing import List, Any, Optional

import gym
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
import gfootball
import gfootball.env as football_env

from ding.envs import BaseEnv, ObsPlusPrevActRewWrapper
from ding.torch_utils import to_ndarray, to_list
from ding.utils import ENV_REGISTRY

from .action.gfootball_action_runner import GfootballRawActionRunner
from .obs.gfootball_obs_runner import GfootballObsRunner
from .reward.gfootball_reward_runner import GfootballRewardRunner


@ENV_REGISTRY.register('gfootball')
class GfootballEnv(BaseEnv):
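    """
    Overview:
        DI-engine environment wrapper around Google Research Football (gfootball).
        It combines the raw gfootball environment with dedicated observation, action
        and reward runner helpers and exposes the standard ``BaseEnv`` interface
        (``reset``, ``step``, ``seed``, ``close``), with optional GIF replay saving.
    """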
    timestep = namedtuple('GfootballTimestep', ['obs', 'reward', 'done', 'info'])
    info_template = namedtuple('GFootballEnvInfo', ['obs_space', 'act_space', 'rew_space'])

    def __init__(self, cfg: dict) -> None:
        self._cfg = cfg
        self._action_helper = GfootballRawActionRunner(cfg)
        self._reward_helper = GfootballRewardRunner(cfg)
        self._obs_helper = GfootballObsRunner(cfg)
        self.save_replay = cfg.get("save_replay", False)
        self._launch_env_flag = False
        self._launch_env()
        self.env_name = self._cfg.env_name
        self._save_replay_gif = self._cfg.get("save_replay_gif", False)
        # Defaults so that GIF saving also works when enable_save_replay() is not called explicitly.
        self._replay_path = './video'
        self._save_replay_gif_count = 0

    def _launch_env(self, gui: bool = False) -> None:
        # Use the 'raw' representation: the obs runner helper consumes the full raw observation dict.
        self._env = football_env.create_environment(
            env_name=self._cfg.env_name,
            representation='raw',
            stacked=False,
            logdir='./tmp/football',
            write_goal_dumps=False,
            write_full_episode_dumps=self.save_replay,
            write_video=self.save_replay,
            render=False
        )
        self._launch_env_flag = True

    def reset(self) -> dict:
        if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
            # Placeholders for the previous transition, returned together with the observation.
            self.prev_action = -1
            self.prev_reward_extrinsic = 0
        if self._save_replay_gif:
            self._frames = []
        if not self._launch_env_flag:
            self._launch_env()
        # Apply the seed before resetting so that it takes effect for the upcoming episode.
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env.seed(self._seed + np_seed)
        elif hasattr(self, '_seed'):
            self._env.seed(self._seed)
        # The raw representation returns a list with one observation dict per controlled player.
        self._football_obs = self._env.reset()[0]
        self._reward_helper.reset()
        self._obs_helper.reset()
        self._action_helper.reset()
        # Build gym spaces from the runner helpers' info descriptions.
        self._observation_space = gym.spaces.Dict(
            {
                'match': gym.spaces.Dict(
                    {
                        k: gym.spaces.Discrete(v['max']) if v['dinfo'] == 'one-hot' else
                        gym.spaces.Box(low=np.array(v['min']), high=np.array(v['max']), dtype=np.float32)
                        for k, v in self._obs_helper.info['match'].value.items()
                    }
                ),
                'player': gym.spaces.Dict(
                    {
                        k: gym.spaces.Discrete(v['max']) if v['dinfo'] == 'one-hot' else
                        gym.spaces.Box(low=np.array(v['min']), high=np.array(v['max']), dtype=np.float32)
                        for k, v in self._obs_helper.info['player'].value['players'].items()
                    }
                )
            }
        )
        self._action_space = gym.spaces.Discrete(self._action_helper.info.shape[0])
        self._reward_space = gym.spaces.Box(
            low=self._reward_helper.info.value['min'],
            high=self._reward_helper.info.value['max'],
            shape=self._reward_helper.info.shape,
            dtype=np.float32
        )
        self.obs = self._obs_helper.get(self)
        if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
            return {
                'obs': {
                    'processed_obs': self.obs,
                    'raw_obs': self._football_obs
                },
                'prev_action': self.prev_action,
                'prev_reward_extrinsic': self.prev_reward_extrinsic
            }
        else:
            return {'processed_obs': self.obs, 'raw_obs': self._football_obs}

    def step(self, action: np.ndarray) -> 'GfootballEnv.timestep':
        assert self._launch_env_flag
        self.agent_action = action
        # The policy outputs a single-element array; the raw gfootball env expects a python int.
        action = action.item()

        if self._save_replay_gif:
            self._frames.append(self._env.render(mode='rgb_array'))
        self._football_obs, self._reward_of_action, self._is_done, self._info = self._env.step(action)
        # The raw env returns a list of per-player observation dicts; keep the first one.
        self._football_obs = self._football_obs[0]
        self.action = self._action_helper.get(self)
        self.reward = self._reward_helper.get(self)
        self.obs = self._obs_helper.get(self)

        info = {'cum_reward': self._reward_helper.cum_reward}
        if self._is_done:
            info['eval_episode_return'] = to_ndarray(self._reward_helper.cum_reward)
            if self._save_replay_gif:
                path = os.path.join(
                    self._replay_path, '{}_episode_{}.gif'.format(self.env_name, self._save_replay_gif_count)
                )
                self.display_frames_as_gif(self._frames, path)
                self._save_replay_gif_count += 1
                print(f'Saved one episode replay gif to {path}')

        self.reward = to_ndarray(self.reward)

        if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
            self.prev_action = action
            self.prev_reward_extrinsic = self.reward
            obs = {
                'obs': {
                    'processed_obs': self.obs,
                    'raw_obs': self._football_obs
                },
                'prev_action': self.prev_action,
                'prev_reward_extrinsic': self.prev_reward_extrinsic
            }
        else:
            obs = {'processed_obs': self.obs, 'raw_obs': self._football_obs}

        return GfootballEnv.timestep(obs, reward=self.reward, done=self._is_done, info=info)

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def close(self) -> None:
        self._env.close()

    def __repr__(self) -> str:
        return 'GfootballEnv:\n\tobservation[{}]\n\taction[{}]\n\treward[{}]\n'.format(
            repr(self._obs_helper), repr(self._action_helper), repr(self._reward_helper)
        )

    def info(self) -> 'GfootballEnv.info_template':
        info_data = {
            'obs_space': self._obs_helper.info,
            'act_space': self._action_helper.info,
            'rew_space': self._reward_helper.info,
        }
        return GfootballEnv.info_template(**info_data)

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        collector_env_num = cfg.pop('collector_env_num', 1)
        cfg = copy.deepcopy(cfg)
        cfg.save_replay = False
        # Give every collector env its own config copy so later mutations do not leak across envs.
        return [copy.deepcopy(cfg) for _ in range(collector_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        evaluator_env_num = cfg.pop('evaluator_env_num', 1)
        cfg = copy.deepcopy(cfg)
        # Evaluator envs write full-episode dumps/videos by default.
        cfg.save_replay = True
        return [copy.deepcopy(cfg) for _ in range(evaluator_env_num)]

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample()
        random_action = to_ndarray([random_action], dtype=np.int64)
        return random_action

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._save_replay_gif = True
        self._replay_path = replay_path
        self._save_replay_gif_count = 0
        # Make sure the target directory exists before any GIF is written.
        if not os.path.exists(replay_path):
            os.makedirs(replay_path)

    @staticmethod
    def display_frames_as_gif(frames: list, path: str) -> None:
        # Render the collected RGB frames into a GIF via matplotlib's animation writer.
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
        anim.save(path, writer='imagemagick', fps=20)


GfootballTimestep = GfootballEnv.timestep
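

# Minimal usage sketch (not part of the original module): it assumes an EasyDict-style
# config providing the keys read above ('env_name', plus optional 'save_replay',
# 'save_replay_gif' and 'obs_plus_prev_action_reward'); the obs/action/reward runner
# helpers may require additional keys, and a local gfootball installation is needed.
if __name__ == '__main__':
    from easydict import EasyDict

    cfg = EasyDict(
        env_name='academy_empty_goal_close',
        save_replay=False,
        save_replay_gif=False,
        obs_plus_prev_action_reward=False,
    )
    env = GfootballEnv(cfg)
    obs = env.reset()
    timestep = env.step(env.random_action())
    print('reward: {}, done: {}'.format(timestep.reward, timestep.done))
    env.close()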
|