File size: 747 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import copy
import torch
from ding.envs.common import EnvElementRunner
from ding.envs.env.base_env import BaseEnv
from .gfootball_reward import GfootballReward
class GfootballRewardRunner(EnvElementRunner):
    """Env-element runner that reads the per-step reward from a gfootball engine.

    Each call to ``get`` reads the engine's ``_reward_of_action``, adds it to a
    running total and converts it with the ``GfootballReward`` core element.
    The running total is exposed via ``cum_reward`` and cleared by ``reset``.
    """

    def _init(self, cfg, *args, **kwargs) -> None:
        """Create the reward core element and zero the running reward total.

        Arguments:
            - cfg: configuration forwarded unchanged to ``GfootballReward``.
            - args, kwargs: accepted for signature compatibility; unused here.
        """
        # NOTE(review): presumably invoked by EnvElementRunner's constructor
        # (the leading underscore suggests a template hook) — confirm.
        self._core = GfootballReward(cfg)
        self._cum_reward = 0.0

    def get(self, engine: BaseEnv) -> torch.Tensor:
        """Return the engine's latest reward in agent format and accumulate it.

        Arguments:
            - engine: environment whose ``_reward_of_action`` attribute is read.

        Returns:
            - the copied reward passed through ``self._core._to_agent_processor``.
        """
        # Copied, presumably so that later mutation of the engine's reward
        # object cannot affect the value accumulated and returned here.
        reward = copy.deepcopy(engine._reward_of_action)
        self._cum_reward += reward
        # NOTE(review): the concrete return type is whatever
        # GfootballReward._to_agent_processor produces; the annotation keeps
        # the original intent (a tensor) — confirm against that method.
        return self._core._to_agent_processor(reward)

    def reset(self) -> None:
        """Clear the accumulated reward back to 0.0."""
        self._cum_reward = 0.0

    @property
    def cum_reward(self) -> torch.Tensor:
        """Reward summed since ``_init`` or the last ``reset``, as a 1-element FloatTensor."""
        return torch.FloatTensor([self._cum_reward])
|