ppo-procgen-coinrun-easy / wrappers /video_compat_wrapper.py
sgoodfriend's picture
PPO playing procgen-coinrun-easy from https://github.com/sgoodfriend/rl-algo-impls/tree/21ee1ab96a186676e5ed2f8c3185902f7c7bca7a
a9b202e
raw
history blame contribute delete
380 Bytes
import gym
import numpy as np
class VideoCompatWrapper(gym.Wrapper):
    """Coerce ``rgb_array`` render frames to ``uint8`` for video compatibility.

    When the wrapped env's ``render`` returns a NumPy array with a dtype other
    than ``uint8`` for ``mode="rgb_array"``, the array is saturated to the
    [0, 255] range and cast to ``uint8``. Every other mode and return value is
    passed through unchanged. Construction is inherited from ``gym.Wrapper``
    (``VideoCompatWrapper(env)``).
    """

    def render(self, mode="human", **kwargs):
        """Render the wrapped env, converting ``rgb_array`` frames to ``uint8``.

        Args:
            mode: Render mode, forwarded to the wrapped env's ``render``.
            **kwargs: Extra keyword arguments, forwarded unchanged.

        Returns:
            The wrapped env's render result. A non-``uint8`` ``np.ndarray``
            returned for ``mode == "rgb_array"`` is clipped to [0, 255] and
            cast to ``uint8``; anything else is returned as-is.
        """
        frame = super().render(mode=mode, **kwargs)
        needs_cast = (
            mode == "rgb_array"
            and isinstance(frame, np.ndarray)
            and frame.dtype != np.uint8
        )
        if needs_cast:
            # A bare astype(uint8) wraps out-of-range integers modulo 256
            # (256 -> 0, -1 -> 255) and is undefined for out-of-range floats,
            # so saturate first. Typed uint8 bounds let NumPy promote narrow
            # dtypes (e.g. int8) rather than fail to represent 255.
            # NOTE(review): float frames normalized to [0, 1] would still
            # truncate to ~0 here; confirm wrapped envs emit 0-255 values.
            frame = np.clip(frame, np.uint8(0), np.uint8(255)).astype(np.uint8)
        return frame