File size: 2,953 Bytes
b05c680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import gym
import numpy as np
from typing import Any, Dict, Tuple, Union
# Observation type produced by the wrapped environments: an array or a dict.
ObsType = Union[np.ndarray, dict]
# Action type accepted by the wrappers: a scalar (int/float) or anything an
# observation can be. typing.Union flattens the nested union, so this equals
# Union[int, float, np.ndarray, dict].
ActType = Union[int, float, ObsType]
class EpisodicLifeEnv(gym.Wrapper):
    """Report the loss of a life as the end of an episode while training.

    When ``training`` is True and the life counter returned by
    ``env.unwrapped.ale.lives()`` drops (but stays above zero) without the
    wrapped env reporting ``done``, ``step`` returns ``done=True``. The next
    ``reset`` then does not restart the game. Instead it takes one
    ``noop_act`` step so play continues with the next life. When
    ``training`` is False, ``done`` is passed through unchanged.

    Args:
        env: Environment to wrap; must expose ``env.unwrapped.ale.lives()``.
        training: Whether a life loss should be reported as ``done``.
        noop_act: Action used to advance the game on a life-loss "reset".
    """

    def __init__(self, env: gym.Env, training: bool = True, noop_act: int = 0) -> None:
        super().__init__(env)
        self.training = training
        self.noop_act = noop_act
        # True only when the last ``done`` returned by ``step`` was a converted
        # life loss (not a real game over), so ``reset`` should continue play.
        self.life_done_continue = False
        self.lives = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the env; when training, turn a non-terminal life loss into ``done``."""
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        # Only a genuine life loss (counter drops, stays > 0, game not over)
        # qualifies. A bonus life or the counter reaching 0 must not make the
        # next reset() skip the real env.reset().
        life_lost = 0 < new_lives < self.lives and not done
        self.life_done_continue = bool(self.training and life_lost)
        if self.life_done_continue:
            done = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        """Continue after a converted life loss; otherwise fully reset the env.

        Keyword arguments are forwarded to ``env.reset`` only on a full reset.
        """
        done = True
        if self.training and self.life_done_continue:
            # Game not over: advance one no-op step into the next life.
            obs, _, done, _ = self.env.step(self.noop_act)
        if done:
            # Real game over, or the no-op step itself ended the game.
            obs = self.env.reset(**kwargs)
        # Consume the flag so a repeated reset() performs a real reset.
        self.life_done_continue = False
        self.lives = self.env.unwrapped.ale.lives()
        return obs
class FireOnLifeStarttEnv(gym.Wrapper):
    """Press FIRE automatically at the start of a game and after each lost life.

    Presumably intended for games that need FIRE to start or resume play.
    ``reset`` presses ``fire_act`` and then action ``2``. After a non-terminal
    life loss, the agent's next action is replaced by ``fire_act``.

    Args:
        env: Environment whose ``env.unwrapped.get_action_meanings()`` has
            "FIRE" at index ``fire_act`` and at least three actions, and which
            exposes ``env.unwrapped.ale.lives()``.
        fire_act: Index of the FIRE action.
    """

    # NOTE(review): "Startt" looks like a typo in the class name; it is kept
    # unchanged so existing references to this name keep working.

    def __init__(self, env: gym.Env, fire_act: int = 1) -> None:
        super().__init__(env)
        self.fire_act = fire_act
        action_meanings = env.unwrapped.get_action_meanings()
        assert action_meanings[fire_act] == "FIRE", (
            f"action {fire_act} is {action_meanings[fire_act]!r}, expected 'FIRE'"
        )
        # reset() also takes action 2, so at least three actions are required.
        assert len(action_meanings) >= 3, "at least 3 actions are required"
        self.lives = 0
        self.fire_on_next_step = True

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the env, substituting ``fire_act`` if a new life has just begun."""
        if self.fire_on_next_step:
            action = self.fire_act
            self.fire_on_next_step = False
        obs, rew, done, info = self.env.step(action)
        new_lives = self.env.unwrapped.ale.lives()
        # A life was lost but the game continues: press FIRE on the next step.
        if 0 < new_lives < self.lives and not done:
            self.fire_on_next_step = True
        self.lives = new_lives
        return obs, rew, done, info

    def reset(self, **kwargs) -> ObsType:
        """Reset the env, then press ``fire_act`` followed by action 2.

        If either start-up step ends the episode, the env is reset again. When
        that happens after the second step, the observation from that final
        reset is returned.
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(self.fire_act)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            # Return the fresh observation, not the terminal one from step(2).
            obs = self.env.reset(**kwargs)
        self.fire_on_next_step = False
        # Track lives from the start of this episode, not the previous one.
        self.lives = self.env.unwrapped.ale.lives()
        return obs
class ClipRewardEnv(gym.Wrapper):
    """Replace each reward by ``np.sign(reward)`` while training.

    In training mode the original reward is stored in
    ``info["unclipped_reward"]`` before clipping. Otherwise every step result
    passes through untouched.

    Args:
        env: Environment to wrap.
        training: Whether rewards should be clipped.
    """

    def __init__(self, env: gym.Env, training: bool = True) -> None:
        super().__init__(env)
        self.training = training

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the env; when training, record the raw reward and return its sign."""
        observation, reward, is_done, extra = self.env.step(action)
        if not self.training:
            return observation, reward, is_done, extra
        extra["unclipped_reward"] = reward
        return observation, np.sign(reward), is_done, extra
|