File size: 461 Bytes
8bf4dee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import gym
import numpy as np

from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper


class VideoCompatWrapper(VecotarableWrapper):
    """Env wrapper that normalizes ``rgb_array`` render output to ``uint8``.

    Every call is forwarded to the wrapped environment's ``render``. When the
    mode is ``"rgb_array"`` and the result is an ``np.ndarray`` with a dtype
    other than ``np.uint8``, a ``uint8`` copy is returned instead. Every other
    result is handed back untouched. The name suggests this exists so frames
    can be fed to video recording; that is an assumption to confirm with
    callers.
    """

    def __init__(self, env: gym.Env) -> None:
        # Nothing extra to set up; the base wrapper stores ``env``.
        super().__init__(env)

    def render(self, mode="human", **kwargs):
        """Render the wrapped env, casting non-``uint8`` RGB frames to ``uint8``.

        Args:
            mode: Render mode forwarded to the wrapped env. Only
                ``"rgb_array"`` triggers the dtype conversion.
            **kwargs: Extra keyword arguments forwarded verbatim to the
                wrapped env's ``render``.

        Returns:
            Whatever the wrapped env's ``render`` returned. The one exception
            is an ``np.ndarray`` frame in ``"rgb_array"`` mode whose dtype is
            not ``uint8``: a ``uint8`` copy of it is returned instead.
        """
        frame = super().render(mode=mode, **kwargs)
        # Same short-circuit order as a plain ``and`` chain: the dtype is
        # only inspected once we know ``frame`` is an ndarray.
        needs_cast = (
            mode == "rgb_array"
            and isinstance(frame, np.ndarray)
            and frame.dtype != np.uint8
        )
        # NOTE(review): this is a plain cast with no rescaling or clipping.
        # It presumes pixel values are already in 0..255. Float frames in
        # [0, 1] would collapse toward 0, so confirm for float-rendering envs.
        return frame.astype(np.uint8) if needs_cast else frame