PPO playing impala-BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/e47a44c4d891f48885af0b1605b30d19fc67b5af
b05c680
import gym | |
import gym.spaces.dict | |
import numpy as np | |
from gym import ObservationWrapper | |
class GetRgbObservation(ObservationWrapper):
    """Observation wrapper that exposes only the ``"rgb"`` entry of a Dict observation.

    The wrapped environment must have a ``gym.spaces.dict.Dict`` observation
    space containing an ``"rgb"`` key. After wrapping, ``observation_space``
    is that ``"rgb"`` sub-space and every observation passed through
    ``observation()`` is reduced to its ``"rgb"`` value.

    If the wrapped environment reports ``is_vector_env`` as truthy, its
    ``single_observation_space`` is narrowed to the ``"rgb"`` sub-space too.
    """

    def __init__(self, env) -> None:
        """Wrap ``env`` and narrow its observation space(s) to the ``"rgb"`` entry.

        Args:
            env: Environment whose ``observation_space`` is a
                ``gym.spaces.dict.Dict`` with an ``"rgb"`` key.

        Raises:
            AssertionError: If ``env.observation_space`` is not a Dict space.
            KeyError: If the Dict space has no ``"rgb"`` entry.
        """
        super().__init__(env)
        assert isinstance(
            env.observation_space, gym.spaces.dict.Dict
        ), "GetRgbObservation requires a Dict observation space"
        self.observation_space = env.observation_space["rgb"]  # type: ignore
        # Default to False so environments that do not define `is_vector_env`
        # are treated as non-vector envs instead of raising AttributeError.
        if getattr(env, "is_vector_env", False):
            self.single_observation_space = env.single_observation_space["rgb"]  # type: ignore

    def observation(self, observation: dict) -> np.ndarray:
        """Return the ``"rgb"`` entry of a dict-style observation.

        Args:
            observation: Mapping produced by the wrapped environment; it is
                indexed with the key ``"rgb"``.

        Returns:
            The value stored under ``"rgb"``.
        """
        return observation["rgb"]  # type: ignore