# genrl / envs / main.py
# Author: mazpie — "Initial commit" (2d9a728)
from collections import OrderedDict, deque
from typing import Any, NamedTuple
import os
import dm_env
import numpy as np
from dm_env import StepType, specs
import gym
import torch
class ExtendedTimeStep(NamedTuple):
    """dm_env-style time step that also records the action which produced it.

    Besides attribute access, fields can be fetched by name with item syntax
    (``ts['reward']``). Because ``__getitem__`` is overridden to look up
    attributes, integer indexing such as ``ts[0]`` does not work.
    """

    step_type: Any
    reward: Any
    discount: Any
    observation: Any
    action: Any

    def _has_step_type(self, kind):
        """Compare ``step_type`` against one ``StepType`` member."""
        return self.step_type == kind

    def first(self):
        """Whether ``step_type`` is ``StepType.FIRST``."""
        return self._has_step_type(StepType.FIRST)

    def mid(self):
        """Whether ``step_type`` is ``StepType.MID``."""
        return self._has_step_type(StepType.MID)

    def last(self):
        """Whether ``step_type`` is ``StepType.LAST``."""
        return self._has_step_type(StepType.LAST)

    def __getitem__(self, attr):
        """Return the field named ``attr`` (dict-style access)."""
        value = getattr(self, attr)
        return value
class FlattenJacoObservationWrapper(dm_env.Environment):
    """Flatten a Jaco observation dict into ``'pixels'`` plus one state vector.

    The wrapped observation is split into:
      * ``'pixels'``: the ``'front_close'`` camera image (only if present),
        with singleton dimensions squeezed away;
      * ``'observations'``: every other entry raveled and concatenated, in
        dict order, into a single float32 vector.
    """

    def __init__(self, env):
        self._env = env
        self._obs_spec = OrderedDict()
        wrapped_obs_spec = env.observation_spec().copy()
        if 'front_close' in wrapped_obs_spec:
            spec = wrapped_obs_spec.pop('front_close')
            # drop batch dim
            self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:],
                                                          dtype=spec.dtype,
                                                          minimum=spec.minimum,
                                                          maximum=spec.maximum,
                                                          name='pixels')
        # The remaining (non-image) entries are expected to be plain float64 arrays.
        for spec in wrapped_obs_spec.values():
            assert spec.dtype == np.float64
            assert type(spec) == specs.Array
        dim = int(sum(int(np.prod(spec.shape)) for spec in wrapped_obs_spec.values()))
        self._obs_spec['observations'] = specs.Array(shape=(dim,),
                                                     dtype=np.float32,
                                                     name='observations')

    def _transform_observation(self, time_step):
        """Build the flattened observation without mutating the wrapped one."""
        wrapped_obs = time_step.observation
        obs = OrderedDict()
        if 'front_close' in wrapped_obs:
            # The spec drops only the leading batch dim; squeeze assumes that
            # is the only singleton dimension of the image.
            obs['pixels'] = np.squeeze(wrapped_obs['front_close'])
        # Skip the image key instead of popping it: the dict belongs to the
        # wrapped environment's time step and must not be modified here.
        features = [value.ravel() for key, value in wrapped_obs.items()
                    if key != 'front_close']
        if features:
            flat = np.concatenate(features, axis=0)
        else:
            # np.concatenate rejects an empty list; the spec is then shape (0,).
            flat = np.zeros((0,), dtype=np.float32)
        # Match the dtype declared in observation_spec()['observations'].
        obs['observations'] = flat.astype(np.float32, copy=False)
        return time_step._replace(observation=obs)

    def reset(self):
        time_step = self._env.reset()
        return self._transform_observation(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        return self._transform_observation(time_step)

    def observation_spec(self):
        return self._obs_spec

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)
class ActionRepeatWrapper(dm_env.Environment):
    """Repeat every action ``num_repeats`` times on the wrapped environment.

    Repetition stops early once the wrapped environment reports the last step.
    The returned time step is the final inner one, with ``reward`` replaced by
    the discount-weighted sum of the inner rewards and ``discount`` by the
    product of the inner discounts.
    """

    def __init__(self, env, num_repeats):
        self._env = env
        self._num_repeats = num_repeats

    def reset(self):
        return self._env.reset()

    def step(self, action):
        total_reward = 0.0
        running_discount = 1.0
        for _ in range(self._num_repeats):
            inner_step = self._env.step(action)
            # A falsy reward (None or 0) contributes zero.
            total_reward += (inner_step.reward or 0.0) * running_discount
            running_discount *= inner_step.discount
            if inner_step.last():
                break
        return inner_step._replace(reward=total_reward, discount=running_discount)

    def action_spec(self):
        return self._env.action_spec()

    def observation_spec(self):
        return self._env.observation_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)
class FramesWrapper(dm_env.Environment):
    """Replace the observation with a stack of the most recent pixel frames.

    Frames are taken from ``observation[pixels_key]`` (a leading batch dim is
    dropped), transposed with axes ``(2, 0, 1)`` and the last ``num_frames`` of
    them are concatenated along axis 0. The advertised spec is uint8 with
    shape ``(shape[2] * num_frames, shape[0], shape[1])``.
    """

    def __init__(self, env, num_frames=1, pixels_key='pixels'):
        self._env = env
        self._num_frames = num_frames
        self._frames = deque([], maxlen=num_frames)
        self._pixels_key = pixels_key

        inner_spec = env.observation_spec()
        assert pixels_key in inner_spec
        frame_shape = inner_spec[pixels_key].shape
        if len(frame_shape) == 4:
            # remove batch dim
            frame_shape = frame_shape[1:]
        stacked_shape = (frame_shape[2] * num_frames, frame_shape[0], frame_shape[1])
        self._obs_spec = specs.BoundedArray(shape=stacked_shape,
                                            dtype=np.uint8,
                                            minimum=0,
                                            maximum=255,
                                            name='observation')

    def _extract_pixels(self, time_step):
        frame = time_step.observation[self._pixels_key]
        if len(frame.shape) == 4:
            # remove batch dim
            frame = frame[0]
        # Move the last axis first; copy so the stored frame owns its memory.
        channels_first = frame.transpose(2, 0, 1)
        return channels_first.copy()

    def _transform_observation(self, time_step):
        assert len(self._frames) == self._num_frames
        stacked = np.concatenate(list(self._frames), axis=0)
        return time_step._replace(observation=stacked)

    def reset(self):
        time_step = self._env.reset()
        first_frame = self._extract_pixels(time_step)
        # Fill the whole history with the first frame so the stack is complete.
        self._frames.extend([first_frame] * self._num_frames)
        return self._transform_observation(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        self._frames.append(self._extract_pixels(time_step))
        return self._transform_observation(time_step)

    def observation_spec(self):
        return self._obs_spec

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)
class OneHotAction(gym.Wrapper):
    """Expose a ``Discrete`` action space as one-hot float32 vectors.

    ``step`` accepts a one-hot vector, validates it, and forwards the index of
    its maximum to the wrapped environment.
    """

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        super().__init__(env)
        self._random = np.random.RandomState()
        n_actions = self.env.action_space.n
        one_hot_space = gym.spaces.Box(low=0, high=1, shape=(n_actions,), dtype=np.float32)
        one_hot_space.discrete = True
        self.action_space = one_hot_space

    def step(self, action):
        index = np.argmax(action).astype(int)
        expected = np.zeros_like(action)
        expected[index] = 1
        if not np.allclose(expected, action):
            raise ValueError(f"Invalid one-hot action:\n{action}")
        return self.env.step(index)

    def reset(self):
        return self.env.reset()

    def _sample_action(self):
        """Return a uniformly random one-hot action vector."""
        n_actions = self.env.action_space.n
        chosen = self._random.randint(0, n_actions)
        one_hot = np.zeros(n_actions, dtype=np.float32)
        one_hot[chosen] = 1.0
        return one_hot
class ActionDTypeWrapper(dm_env.Environment):
    """Advertise the wrapped action spec with a different dtype.

    Incoming actions are cast back to the wrapped environment's own action
    dtype before they are forwarded.
    """

    def __init__(self, env, dtype):
        self._env = env
        inner = env.action_spec()
        self._action_spec = specs.BoundedArray(inner.shape, dtype, inner.minimum,
                                               inner.maximum, 'action')

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._env.observation_spec()

    def step(self, action):
        native_dtype = self._env.action_spec().dtype
        return self._env.step(action.astype(native_dtype))

    def reset(self):
        return self._env.reset()

    def __getattr__(self, name):
        return getattr(self._env, name)
class ObservationDTypeWrapper(dm_env.Environment):
    """Replace the observation dict with its ``'observations'`` entry cast to ``dtype``."""

    def __init__(self, env, dtype):
        self._env = env
        self._dtype = dtype
        source = env.observation_spec()['observations']
        self._obs_spec = specs.Array(source.shape, dtype, 'observation')

    def _transform_observation(self, time_step):
        flat = time_step.observation['observations']
        return time_step._replace(observation=flat.astype(self._dtype))

    def step(self, action):
        return self._transform_observation(self._env.step(action))

    def reset(self):
        return self._transform_observation(self._env.reset())

    def observation_spec(self):
        return self._obs_spec

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)
class ExtendedTimeStepWrapper(dm_env.Environment):
    """Convert wrapped time steps into ``ExtendedTimeStep`` objects.

    The extended step additionally records the action that produced it (all
    zeros, shaped like the action spec, on reset). A missing (``None``)
    reward becomes 0.0 and a missing discount becomes 1.0.
    """

    def __init__(self, env):
        self._env = env

    def reset(self):
        time_step = self._env.reset()
        return self._augment_time_step(time_step)

    def step(self, action):
        time_step = self._env.step(action)
        return self._augment_time_step(time_step, action)

    def _augment_time_step(self, time_step, action=None):
        if action is None:
            action_spec = self.action_spec()
            action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
        # Substitute defaults only for missing values. Using ``x or default``
        # would also turn a terminal discount of 0.0 into 1.0 and hide
        # episode terminations from downstream consumers.
        reward = 0.0 if time_step.reward is None else time_step.reward
        discount = 1.0 if time_step.discount is None else time_step.discount
        return ExtendedTimeStep(observation=time_step.observation,
                                step_type=time_step.step_type,
                                action=action,
                                reward=reward,
                                discount=discount)

    def observation_spec(self):
        return self._env.observation_spec()

    def action_spec(self):
        return self._env.action_spec()

    def __getattr__(self, name):
        return getattr(self._env, name)
class DMC:
    """Adapt a dm_env environment to the dict-observation interface.

    ``step``/``reset`` return ``(time_step, obs)`` where ``obs`` holds the
    reward, episode flags, observation, action and discount. ``obs_space``
    and ``act_space`` are computed on access and returned as dicts.
    """

    def __init__(self, env):
        self._env = env
        self._ignored_keys = []

    @staticmethod
    def _action_bounds(spec):
        """Return float32 ``(low, high)`` arrays of shape ``spec.shape``.

        ``spec.minimum``/``spec.maximum`` may be scalars or full arrays; both
        are broadcast to the action shape. (Multiplying them by the action
        dimension would scale the bounds instead of repeating them.)
        """
        shape = tuple(spec.shape)
        low = np.broadcast_to(np.asarray(spec.minimum, dtype=np.float32), shape).copy()
        high = np.broadcast_to(np.asarray(spec.maximum, dtype=np.float32), shape).copy()
        return low, high

    def step(self, action):
        time_step = self._env.step(action)
        assert time_step.discount in (0, 1)
        obs = {
            'reward': time_step.reward,
            'is_first': False,
            'is_last': time_step.last(),
            'is_terminal': time_step.discount == 0,
            'observation': time_step.observation,
            'action': action,
            'discount': time_step.discount,
        }
        return time_step, obs

    def reset(self):
        time_step = self._env.reset()
        obs = {
            'reward': 0.0,
            'is_first': True,
            'is_last': False,
            'is_terminal': False,
            'observation': time_step.observation,
            'action': np.zeros_like(self.act_space['action'].sample()),
            'discount': time_step.discount,
        }
        return time_step, obs

    def __getattr__(self, name):
        if name == 'obs_space':
            obs_spaces = {
                'observation': self._env.observation_spec(),
                'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
                'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
                'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
            }
            return obs_spaces
        if name == 'act_space':
            spec = self._env.action_spec()
            low, high = self._action_bounds(spec)
            action = gym.spaces.Box(low, high, shape=spec.shape, dtype=np.float32)
            return {'action': action}
        return getattr(self._env, name)
# NOTE: a second, byte-for-byte identical definition of ``OneHotAction`` used
# to live here and silently shadowed the one defined earlier in this module.
# The duplicate was removed; the earlier definition is the single source.
class KitchenWrapper:
    """Wrap a kitchen task with the ``(TimeStep, obs)`` interface used here.

    ``obs`` carries the rendered image (transposed channel-first), the
    float32 state vector, episode flags, the action and a constant discount
    of 1. After :meth:`eval_mode`, the per-step reward is capped at 1 and any
    positive reward marks the step as last.
    """

    def __init__(
        self,
        name,
        seed=0,
        action_repeat=1,
        size=(64, 64),
    ):
        import envs.kitchen_extra as kitchen_extra
        task_classes = {
            'microwave': kitchen_extra.KitchenMicrowaveV0,
            'kettle': kitchen_extra.KitchenKettleV0,
            'burner': kitchen_extra.KitchenBurnerV0,
            'light': kitchen_extra.KitchenLightV0,
            'hinge': kitchen_extra.KitchenHingeV0,
            'slide': kitchen_extra.KitchenSlideV0,
            'top_burner': kitchen_extra.KitchenTopBurnerV0,
        }
        self._env = task_classes[name]()
        self._size = size
        self._action_repeat = action_repeat
        # NOTE(review): the seed is stored but never applied to the
        # environment in this class — confirm seeding happens elsewhere.
        self._seed = seed
        self._eval = False

    def eval_mode(self,):
        """Switch to evaluation: sparse rewards, episode ends on success."""
        self._env.dense = False
        self._eval = True

    @property
    def obs_space(self):
        flag = lambda: gym.spaces.Box(0, 1, (), dtype=bool)
        return {
            "observation": gym.spaces.Box(0, 255, (3,) + self._size, dtype=np.uint8),
            "is_first": flag(),
            "is_last": flag(),
            "is_terminal": flag(),
            "state": self._env.observation_space,
        }

    @property
    def act_space(self):
        return {"action": self._env.action_space}

    def step(self, action):
        total_reward = 0.0
        for _ in range(self._action_repeat):
            state, step_reward, done, info = self._env.step(action.copy())
            total_reward += step_reward
        obs = {
            "reward": total_reward,
            "is_first": False,
            "is_last": False,  # will be handled by timelimit wrapper
            "is_terminal": False,  # will be handled by per_episode function
            "observation": info['images'].transpose(2, 0, 1).copy(),
            "state": state.astype(np.float32),
            'action': action,
            'discount': 1,
        }
        if self._eval:
            obs['reward'] = min(obs['reward'], 1)
            if obs['reward'] > 0:
                obs['is_last'] = True
        step_type = dm_env.StepType.LAST if obs['is_last'] else dm_env.StepType.MID
        time_step = dm_env.TimeStep(
            step_type=step_type,
            reward=obs['reward'],
            discount=1,
            observation=obs['observation'])
        return time_step, obs

    def reset(self,):
        state = self._env.reset()
        obs = {
            "reward": 0.0,
            "is_first": True,
            "is_last": False,
            "is_terminal": False,
            "observation": self.get_visual_obs(self._size),
            "state": state.astype(np.float32),
            'action': np.zeros_like(self.act_space['action'].sample()),
            'discount': 1,
        }
        time_step = dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=None,
            discount=None,
            observation=obs['observation'])
        return time_step, obs

    def __getattr__(self, name):
        # obs_space/act_space are properties, so these branches only run if a
        # property itself raised AttributeError.
        if name == 'obs_space':
            return self.obs_space
        if name == 'act_space':
            return self.act_space
        return getattr(self._env, name)

    def get_visual_obs(self, resolution):
        """Render at ``resolution`` and return the image with axes (2, 0, 1)."""
        frame = self._env.render(resolution=resolution,)
        return frame.transpose(2, 0, 1).copy()
class ViClipWrapper:
    """Add ViCLIP video embeddings of recent frames to each observation.

    On every ``reset``/``step`` the current frame (re-rendered at 224x224 when
    ``hd_rendering`` is enabled) is pushed into a rolling buffer of
    ``n_frames`` frames, and ``obs['clip_video']`` is set from
    :meth:`clip_process` on that buffer. While ``accumulate`` is True,
    ``clip_process`` only queues full frame windows and returns zeros; the
    queued windows are embedded in batches by :meth:`process_accumulate`.
    """

    def __init__(self, env, hd_rendering=False, device='cuda'):
        self._env = env
        try:
            from tools.genrl_utils import viclip_global_instance
        except ImportError:
            # No ready-made shared instance is importable: create one here.
            from tools.genrl_utils import ViCLIPGlobalInstance
            viclip_global_instance = ViCLIPGlobalInstance()

        if not viclip_global_instance._instantiated:
            viclip_global_instance.instantiate(device)
        self.viclip_model = viclip_global_instance.viclip
        self.n_frames = self.viclip_model.n_frames
        self.viclip_emb_dim = viclip_global_instance.viclip_emb_dim
        self.buffer = deque(maxlen=self.n_frames)
        # NOTE: these are hardcoded for now, as they are the best settings
        self.accumulate = True
        self.accumulate_buffer = []
        self.anticipate_conv1 = False
        self.hd_rendering = hd_rendering

    def hd_render(self, obs):
        """Return the frame to buffer for ``obs``.

        Without ``hd_rendering`` this is ``obs['observation']``; otherwise the
        scene is re-rendered at 224x224.
        """
        if not self.hd_rendering:
            return obs['observation']
        if self._env._domain_name in ['mw', 'kitchen', 'mujoco']:
            return self.get_visual_obs((224, 224,))
        render_kwargs = {**getattr(self, '_render_kwargs', {})}
        render_kwargs.update({'width': 224, 'height': 224})
        return self._env.physics.render(**render_kwargs).transpose(2, 0, 1)

    def preprocess(self, x):
        """Identity hook applied to windows before they are queued."""
        return x

    def process_accumulate(self, process_at_once=4):  # NOTE: this could be varied for increasing FPS, depending on the size of the GPU
        """Embed every queued frame window, then clear the queue.

        Windows are sent to the model in chunks of ``process_at_once``.

        Returns:
            ``(embeddings, 'clip_video')`` where ``embeddings`` is a list with
            one numpy vector per queued window (empty if nothing was queued).
        """
        # Disable accumulation so clip_process actually runs the model.
        self.accumulate = False
        try:
            n_windows = len(self.accumulate_buffer)
            embeddings = []
            if n_windows > 0:  # np.stack rejects an empty queue
                x = np.stack(self.accumulate_buffer, axis=0)
                chunks = []
                for start in range(0, n_windows, process_at_once):
                    end = min(start + process_at_once, n_windows)
                    chunks.append(self.clip_process(x[start:end], bypass=True).cpu())
                embeds = torch.cat(chunks, dim=0)
                assert embeds.shape[0] == n_windows
                embeddings = [*embeds.numpy()]
        finally:
            # Re-enable accumulation even if the model call fails, so later
            # steps keep queuing windows.
            self.accumulate = True
        self.accumulate_buffer = []
        return embeddings, 'clip_video'

    def process_episode(self, obs, process_at_once=8):
        """Embed every sliding window of ``n_frames`` frames of an episode.

        Args:
            obs: episode frames with shape ``(T, C, H, W)``; each window
                ``obs[j:j + n_frames]`` is embedded.
            process_at_once: number of windows per model call.

        Returns:
            numpy array with one embedding per frame; the first
            ``n_frames - 1`` rows are zero padding (no complete window ends there).
        """
        # NOTE(review): accumulation is switched off here and not restored
        # afterwards (process_accumulate does restore it) — confirm intended.
        self.accumulate = False
        sequences = []
        for j in range(obs.shape[0] - self.n_frames + 1):
            sequences.append(obs[j:j + self.n_frames].copy())
        sequences = np.stack(sequences, axis=0)
        idx_start = 0
        clip_vid = []
        for idx_end in range(process_at_once, sequences.shape[0] + process_at_once, process_at_once):
            x = sequences[idx_start:idx_end]
            with torch.no_grad():
                x = self.clip_process(x, bypass=True)
            clip_vid.append(x)
            idx_start = idx_end
        if len(clip_vid) == 1:  # process all at once
            embeds = clip_vid[0]
        else:
            embeds = torch.cat(clip_vid, dim=0)
        pad = torch.zeros((self.n_frames - 1, *embeds.shape[1:]), device=embeds.device, dtype=embeds.dtype)
        embeds = torch.cat([pad, embeds], dim=0)
        assert embeds.shape[0] == obs.shape[0], f"Shapes are different {embeds.shape[0]} {obs.shape[0]}"
        return embeds.cpu().numpy()

    def get_sequence(self,):
        """Stack the frame buffer into a single batch of one window."""
        return np.expand_dims(np.stack(self.buffer, axis=0), axis=0)

    def clip_process(self, x, bypass=False):
        """Embed a batch of frame windows ``x`` unpacked as ``(B, n_frames, C, H, W)``.

        Returns a zero vector of size ``viclip_emb_dim`` when the frame buffer
        is not full yet (unless ``bypass``), or when accumulating, in which
        case ``x[0]`` is queued for :meth:`process_accumulate` instead.
        Pixel values are divided by 255 before the model's preprocessing.
        """
        if len(self.buffer) != self.n_frames and not bypass:
            return torch.zeros(self.viclip_emb_dim)
        if self.accumulate:
            self.accumulate_buffer.append(self.preprocess(x)[0])
            return torch.zeros(self.viclip_emb_dim)
        with torch.no_grad():
            B, n_frames, C, H, W = x.shape
            obs = torch.from_numpy(x.copy().reshape(B * n_frames, C, H, W)).to(self.viclip_model.device)
            processed_obs = self.viclip_model.preprocess_transf(obs / 255)
            reshaped_obs = processed_obs.reshape(B, n_frames, 3, processed_obs.shape[-2], processed_obs.shape[-1])
            video_embed = self.viclip_model.get_vid_features(reshaped_obs)
            return video_embed.detach()

    def step(self, action):
        ts, obs = self._env.step(action)
        self.buffer.append(self.hd_render(obs))
        obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
        return ts, obs

    def reset(self,):
        # Important to reset the buffer
        self.buffer = deque(maxlen=self.n_frames)
        ts, obs = self._env.reset()
        self.buffer.append(self.hd_render(obs))
        obs['clip_video'] = self.clip_process(self.get_sequence()).cpu().numpy()
        return ts, obs

    def __getattr__(self, name):
        # Never forward dunder lookups (copy/pickle protocol probes) to the
        # wrapped env, consistent with TimeLimit and ClipActionWrapper.
        if name.startswith('__'):
            raise AttributeError(name)
        if name == 'obs_space':
            space = self._env.obs_space
            space['clip_video'] = gym.spaces.Box(-np.inf, np.inf, (self.viclip_emb_dim,), dtype=np.float32)
            return space
        return getattr(self._env, name)
class TimeLimit:
    """End episodes after a fixed number of steps.

    When ``duration`` is truthy and the step counter reaches it, the time step
    is rewritten as ``StepType.LAST`` and ``obs['is_last']`` is set. The
    counter is then cleared, so stepping again requires a reset.
    """

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        assert self._step is not None, 'Must reset environment.'
        ts, obs = self._env.step(action)
        self._step += 1
        limit_reached = bool(self._duration) and self._step >= self._duration
        if not limit_reached:
            return ts, obs
        final_ts = dm_env.TimeStep(dm_env.StepType.LAST, ts.reward, ts.discount, ts.observation)
        obs['is_last'] = True
        self._step = None
        return final_ts, obs

    def reset(self):
        self._step = 0
        return self._env.reset()

    def reset_with_task_id(self, task_id):
        self._step = 0
        return self._env.reset_with_task_id(task_id)
class ClipActionWrapper:
    """Clip actions element-wise to ``[low, high]`` before forwarding them."""

    def __init__(self, env, low=-1.0, high=1.0):
        self._env = env
        self._low = low
        self._high = high

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        clipped_action = np.clip(action, self._low, self._high)
        return self._env.step(clipped_action)

    # This wrapper keeps no per-episode state. In particular it must not set
    # its own ``_step``: an instance attribute would hide the wrapped env's
    # ``_step`` (e.g. TimeLimit's counter) that __getattr__ otherwise forwards.
    def reset(self):
        return self._env.reset()

    def reset_with_task_id(self, task_id):
        return self._env.reset_with_task_id(task_id)
class NormalizeAction:
    """Expose ``act_space[key]`` as ``[-1, 1]`` on every bounded dimension.

    In :meth:`step`, dimensions whose low and high bounds are both finite are
    mapped from ``[-1, 1]`` back to the original range; unbounded dimensions
    pass through unchanged.
    """

    def __init__(self, env, key='action'):
        self._env = env
        self._key = key
        inner_space = env.act_space[key]
        bounded = np.isfinite(inner_space.low) & np.isfinite(inner_space.high)
        self._mask = bounded
        self._low = np.where(bounded, inner_space.low, -1)
        self._high = np.where(bounded, inner_space.high, 1)

    def __getattr__(self, name):
        if name.startswith('__'):
            raise AttributeError(name)
        try:
            return getattr(self._env, name)
        except AttributeError:
            # Missing attributes surface as ValueError, not AttributeError.
            # NOTE(review): this makes hasattr() raise instead of returning False.
            raise ValueError(name)

    @property
    def act_space(self):
        ones = np.ones_like(self._low)
        low = np.where(self._mask, -ones, self._low)
        high = np.where(self._mask, ones, self._high)
        normalized = gym.spaces.Box(low, high, dtype=np.float32)
        return {**self._env.act_space, self._key: normalized}

    def step(self, action):
        normalized = action[self._key]
        span = self._high - self._low
        rescaled = (normalized + 1) / 2 * span + self._low
        restored = np.where(self._mask, rescaled, normalized)
        return self._env.step({**action, self._key: restored})
def _make_jaco(obs_type, domain, task, action_repeat, seed, img_size,):
    """Build a Jaco arm environment from ``envs.custom_dmc_tasks``.

    ``domain`` is accepted for signature parity with ``_make_dmc`` but is not
    used. Actions are cast to float32 and repeated ``action_repeat`` times,
    and the observation dict is flattened.
    """
    import envs.custom_dmc_tasks as cdmc
    jaco = cdmc.make_jaco(task, obs_type, seed, img_size)
    jaco = ActionRepeatWrapper(ActionDTypeWrapper(jaco, np.float32), action_repeat)
    jaco = FlattenJacoObservationWrapper(jaco)
    # Record the square render size on the outermost wrapper.
    jaco._size = (img_size, img_size)
    return jaco
def _make_dmc(obs_type, domain, task, action_repeat, seed, img_size,):
    """Build a DeepMind Control environment.

    Tasks registered in ``dm_control.suite`` are loaded from the suite; any
    other ``(domain, task)`` pair is delegated to ``envs.custom_dmc_tasks``.
    Actions are cast to float32 and repeated ``action_repeat`` times. For
    ``obs_type == 'pixels'`` the observation is replaced by an
    ``img_size`` x ``img_size`` rendering.
    """
    # Only ``suite`` is needed; the previously imported ``manipulation``
    # module was unused and only added import-time cost.
    from dm_control import suite
    import envs.custom_dmc_tasks as cdmc

    visualize_reward = False
    # Built-in suite tasks take precedence over the custom task registry.
    loader = suite.load if (domain, task) in suite.ALL_TASKS else cdmc.make
    env = loader(domain,
                 task,
                 task_kwargs=dict(random=seed),
                 environment_kwargs=dict(flat_observation=True),
                 visualize_reward=visualize_reward)
    env = ActionDTypeWrapper(env, np.float32)
    env = ActionRepeatWrapper(env, action_repeat)
    if obs_type == 'pixels':
        from dm_control.suite.wrappers import pixels
        # zoom in camera for quadruped
        camera_id = dict(locom_rodent=1, quadruped=2).get(domain, 0)
        render_kwargs = dict(height=img_size, width=img_size, camera_id=camera_id)
        env = pixels.Wrapper(env,
                             pixels_only=True,
                             render_kwargs=render_kwargs)
        env._size = (img_size, img_size)
        env._camera = camera_id
    return env
def make(name, obs_type, action_repeat, seed, img_size=64, viclip_encode=False, clip_hd_rendering=False, device='cuda'):
    """Create an environment from a ``'<domain>_<task>'`` name.

    ``kitchen_*`` names build a time-limited ``KitchenWrapper``
    (``280 // action_repeat`` steps). Any other domain is built as a DeepMind
    Control / Jaco environment with EGL rendering, actions rescaled to
    ``[-1, 1]``, and the dict-observation ``DMC`` adapter. Box action spaces
    are then clipped, and ViCLIP video embeddings are added when
    ``viclip_encode`` is set.
    """
    assert obs_type in ['states', 'pixels']
    domain, task = name.split('_', 1)

    if domain == 'kitchen':
        kitchen = KitchenWrapper(task, seed=seed, action_repeat=action_repeat, size=(img_size, img_size))
        env = TimeLimit(kitchen, 280 // action_repeat)
    else:
        # Headless rendering backend for MuJoCo.
        os.environ['PYOPENGL_PLATFORM'] = 'egl'
        os.environ['MUJOCO_GL'] = 'egl'
        aliases = {'cup': 'ball_in_cup', 'point': 'point_mass'}
        domain = aliases.get(domain, domain)
        builder = _make_jaco if domain == 'jaco' else _make_dmc
        env = builder(obs_type, domain, task, action_repeat, seed, img_size)
        env = FramesWrapper(env) if obs_type == 'pixels' else ObservationDTypeWrapper(env, np.float32)
        from dm_control.suite.wrappers import action_scale
        env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
        env = DMC(ExtendedTimeStepWrapper(env))
    env._domain_name = domain

    if isinstance(env.act_space['action'], gym.spaces.Box):
        env = ClipActionWrapper(env)
    if viclip_encode:
        env = ViClipWrapper(env, hd_rendering=clip_hd_rendering, device=device)
    return env