zjowowen's picture
init space
079c32c
from typing import List, Dict, Tuple
from ditk import logging
from copy import deepcopy
from easydict import EasyDict
from torch.utils.data import Dataset
from dataclasses import dataclass
import pickle
import easydict
import torch
import numpy as np
from ding.utils.bfs_helper import get_vi_sequence
from ding.utils import DATASET_REGISTRY, import_module, DatasetNormalizer
from ding.rl_utils import discount_cumsum
@dataclass
class DatasetStatistics:
"""
Overview:
Dataset statistics.
"""
mean: np.ndarray # obs
std: np.ndarray # obs
action_bounds: np.ndarray
@DATASET_REGISTRY.register('naive')
class NaiveRLDataset(Dataset):
"""
Overview:
Naive RL dataset, which is used for offline RL algorithms.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
"""
def __init__(self, cfg) -> None:
"""
Overview:
Initialization method.
Arguments:
- cfg (:obj:`dict`): Config dict.
"""
assert type(cfg) in [str, EasyDict], "invalid cfg type: {}".format(type(cfg))
if isinstance(cfg, EasyDict):
self._data_path = cfg.policy.collect.data_path
elif isinstance(cfg, str):
self._data_path = cfg
with open(self._data_path, 'rb') as f:
self._data: List[Dict[str, torch.Tensor]] = pickle.load(f)
def __len__(self) -> int:
"""
Overview:
Get the length of the dataset.
"""
return len(self._data)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
Overview:
Get the item of the dataset.
"""
return self._data[idx]
@DATASET_REGISTRY.register('d4rl')
class D4RLDataset(Dataset):
"""
Overview:
D4RL dataset, which is used for offline RL algorithms.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
Properties:
- mean (:obj:`np.ndarray`): Mean of the dataset.
- std (:obj:`np.ndarray`): Std of the dataset.
- action_bounds (:obj:`np.ndarray`): Action bounds of the dataset.
- statistics (:obj:`dict`): Statistics of the dataset.
"""
def __init__(self, cfg: dict) -> None:
"""
Overview:
Initialization method.
Arguments:
- cfg (:obj:`dict`): Config dict.
"""
import gym
try:
import d4rl # register d4rl enviroments with open ai gym
except ImportError:
import sys
logging.warning("not found d4rl env, please install it, refer to https://github.com/rail-berkeley/d4rl")
sys.exit(1)
# Init parameters
data_path = cfg.policy.collect.get('data_path', None)
env_id = cfg.env.env_id
# Create the environment
if data_path:
d4rl.set_dataset_path(data_path)
env = gym.make(env_id)
dataset = d4rl.qlearning_dataset(env)
self._cal_statistics(dataset, env)
try:
if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
dataset = self._normalize_states(dataset)
except (KeyError, AttributeError):
# do not normalize
pass
self._data = []
self._load_d4rl(dataset)
@property
def data(self) -> List:
return self._data
def __len__(self) -> int:
"""
Overview:
Get the length of the dataset.
"""
return len(self._data)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
Overview:
Get the item of the dataset.
"""
return self._data[idx]
def _load_d4rl(self, dataset: Dict[str, np.ndarray]) -> None:
"""
Overview:
Load the d4rl dataset.
Arguments:
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
"""
for i in range(len(dataset['observations'])):
trans_data = {}
trans_data['obs'] = torch.from_numpy(dataset['observations'][i])
trans_data['next_obs'] = torch.from_numpy(dataset['next_observations'][i])
trans_data['action'] = torch.from_numpy(dataset['actions'][i])
trans_data['reward'] = torch.tensor(dataset['rewards'][i])
trans_data['done'] = dataset['terminals'][i]
self._data.append(trans_data)
def _cal_statistics(self, dataset, env, eps=1e-3, add_action_buffer=True):
"""
Overview:
Calculate the statistics of the dataset.
Arguments:
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
- env (:obj:`gym.Env`): The environment.
- eps (:obj:`float`): Epsilon.
"""
self._mean = dataset['observations'].mean(0)
self._std = dataset['observations'].std(0) + eps
action_max = dataset['actions'].max(0)
action_min = dataset['actions'].min(0)
if add_action_buffer:
action_buffer = 0.05 * (action_max - action_min)
action_max = (action_max + action_buffer).clip(max=env.action_space.high)
action_min = (action_min - action_buffer).clip(min=env.action_space.low)
self._action_bounds = np.stack([action_min, action_max], axis=0)
def _normalize_states(self, dataset):
"""
Overview:
Normalize the states.
Arguments:
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
"""
dataset['observations'] = (dataset['observations'] - self._mean) / self._std
dataset['next_observations'] = (dataset['next_observations'] - self._mean) / self._std
return dataset
@property
def mean(self):
"""
Overview:
Get the mean of the dataset.
"""
return self._mean
@property
def std(self):
"""
Overview:
Get the std of the dataset.
"""
return self._std
@property
def action_bounds(self) -> np.ndarray:
"""
Overview:
Get the action bounds of the dataset.
"""
return self._action_bounds
@property
def statistics(self) -> dict:
"""
Overview:
Get the statistics of the dataset.
"""
return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds)
@DATASET_REGISTRY.register('hdf5')
class HDF5Dataset(Dataset):
"""
Overview:
HDF5 dataset is saved in hdf5 format, which is used for offline RL algorithms.
The hdf5 format is a common format for storing large numerical arrays in Python.
For more details, please refer to https://support.hdfgroup.org/HDF5/.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
Properties:
- mean (:obj:`np.ndarray`): Mean of the dataset.
- std (:obj:`np.ndarray`): Std of the dataset.
- action_bounds (:obj:`np.ndarray`): Action bounds of the dataset.
- statistics (:obj:`dict`): Statistics of the dataset.
"""
def __init__(self, cfg: dict) -> None:
"""
Overview:
Initialization method.
Arguments:
- cfg (:obj:`dict`): Config dict.
"""
try:
import h5py
except ImportError:
import sys
logging.warning("not found h5py package, please install it trough `pip install h5py ")
sys.exit(1)
data_path = cfg.policy.collect.get('data_path', None)
if 'dataset' in cfg:
self.context_len = cfg.dataset.context_len
else:
self.context_len = 0
data = h5py.File(data_path, 'r')
self._load_data(data)
self._cal_statistics()
try:
if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
self._normalize_states()
except (KeyError, AttributeError):
# do not normalize
pass
def __len__(self) -> int:
"""
Overview:
Get the length of the dataset.
"""
return len(self._data['obs']) - self.context_len
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
Overview:
Get the item of the dataset.
Arguments:
- idx (:obj:`int`): The index of the dataset.
"""
if self.context_len == 0: # for other offline RL algorithms
return {k: self._data[k][idx] for k in self._data.keys()}
else: # for decision transformer
block_size = self.context_len
done_idx = idx + block_size
idx = done_idx - block_size
states = torch.as_tensor(
np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32
).view(block_size, -1)
actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
traj_mask = torch.ones(self.context_len, dtype=torch.long)
return timesteps, states, actions, rtgs, traj_mask
def _load_data(self, dataset: Dict[str, np.ndarray]) -> None:
"""
Overview:
Load the dataset.
Arguments:
- dataset (:obj:`Dict[str, np.ndarray]`): The dataset.
"""
self._data = {}
for k in dataset.keys():
logging.info(f'Load {k} data.')
self._data[k] = dataset[k][:]
def _cal_statistics(self, eps: float = 1e-3):
"""
Overview:
Calculate the statistics of the dataset.
Arguments:
- eps (:obj:`float`): Epsilon.
"""
self._mean = self._data['obs'].mean(0)
self._std = self._data['obs'].std(0) + eps
action_max = self._data['action'].max(0)
action_min = self._data['action'].min(0)
buffer = 0.05 * (action_max - action_min)
action_max = action_max.astype(float) + buffer
action_min = action_max.astype(float) - buffer
self._action_bounds = np.stack([action_min, action_max], axis=0)
def _normalize_states(self):
"""
Overview:
Normalize the states.
"""
self._data['obs'] = (self._data['obs'] - self._mean) / self._std
self._data['next_obs'] = (self._data['next_obs'] - self._mean) / self._std
@property
def mean(self):
"""
Overview:
Get the mean of the dataset.
"""
return self._mean
@property
def std(self):
"""
Overview:
Get the std of the dataset.
"""
return self._std
@property
def action_bounds(self) -> np.ndarray:
"""
Overview:
Get the action bounds of the dataset.
"""
return self._action_bounds
@property
def statistics(self) -> dict:
"""
Overview:
Get the statistics of the dataset.
"""
return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds)
@DATASET_REGISTRY.register('d4rl_trajectory')
class D4RLTrajectoryDataset(Dataset):
"""
Overview:
D4RL trajectory dataset, which is used for offline RL algorithms.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
"""
# from infos.py from official d4rl github repo
REF_MIN_SCORE = {
'halfcheetah': -280.178953,
'walker2d': 1.629008,
'hopper': -20.272305,
}
REF_MAX_SCORE = {
'halfcheetah': 12135.0,
'walker2d': 4592.3,
'hopper': 3234.3,
}
# calculated from d4rl datasets
D4RL_DATASET_STATS = {
'halfcheetah-medium-v2': {
'state_mean': [
-0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164,
-0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436,
5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435,
0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445,
0.013382787816226482
],
'state_std': [
0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184,
0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577,
1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098,
5.671932697296143, 7.4982590675354
]
},
'halfcheetah-medium-replay-v2': {
'state_mean': [
-0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193,
-0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682,
3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752,
0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994,
-0.015839405357837677
],
'state_std': [
0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494,
0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578,
1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416,
6.085654258728027, 7.25300407409668
]
},
'halfcheetah-medium-expert-v2': {
'state_mean': [
-0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338,
-0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053,
8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784,
0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314
],
'state_std': [
0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533,
0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467,
1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797,
6.4811787605285645, 6.378620147705078
]
},
'walker2d-medium-v2': {
'state_mean': [
1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026,
-0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777,
-0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654,
0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654
],
'state_std': [
0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724,
0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 1.2387926578521729, 0.798020601272583,
1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145,
3.7445690631866455, 5.5851287841796875
]
},
'walker2d-medium-replay-v2': {
'state_mean': [
1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221,
-0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662,
-0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088,
-0.08934258669614792, -0.2992438077926636, -0.5984178185462952
],
'state_std': [
0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303,
0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276,
2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096,
3.845186948776245, 5.4768385887146
]
},
'walker2d-medium-expert-v2': {
'state_mean': [
1.2294334173202515, 0.16869689524173737, -0.07089081406593323, -0.16197483241558075,
0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122,
3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811,
-0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786,
-0.27366524934768677
],
'state_std': [
0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586,
0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831,
1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857,
4.039782524108887, 5.891613960266113
]
},
'hopper-medium-v2': {
'state_mean': [
1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081,
2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474,
-0.18540096282958984, -0.28461286425590515
],
'state_std': [
0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535,
0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754,
5.607253551483154
]
},
'hopper-medium-replay-v2': {
'state_mean': [
1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224,
0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328,
-0.5287045240402222, -0.14465883374214172, -0.19652697443962097
],
'state_std': [
0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718,
1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137,
5.108601093292236
]
},
'hopper-medium-expert-v2': {
'state_mean': [
1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415,
0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272,
-0.1766270101070404, -0.11862941086292267, -0.12097819894552231
],
'state_std': [
0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771,
0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893,
5.725032806396484
]
},
}
def __init__(self, cfg: dict) -> None:
"""
Overview:
Initialization method.
Arguments:
- cfg (:obj:`dict`): Config dict.
"""
dataset_path = cfg.dataset.data_dir_prefix
rtg_scale = cfg.dataset.rtg_scale
self.context_len = cfg.dataset.context_len
self.env_type = cfg.dataset.env_type
if 'hdf5' in dataset_path: # for mujoco env
try:
import h5py
import collections
except ImportError:
import sys
logging.warning("not found h5py package, please install it trough `pip install h5py ")
sys.exit(1)
dataset = h5py.File(dataset_path, 'r')
N = dataset['rewards'].shape[0]
data_ = collections.defaultdict(list)
use_timeouts = False
if 'timeouts' in dataset:
use_timeouts = True
episode_step = 0
paths = []
for i in range(N):
done_bool = bool(dataset['terminals'][i])
if use_timeouts:
final_timestep = dataset['timeouts'][i]
else:
final_timestep = (episode_step == 1000 - 1)
for k in ['observations', 'actions', 'rewards', 'terminals']:
data_[k].append(dataset[k][i])
if done_bool or final_timestep:
episode_step = 0
episode_data = {}
for k in data_:
episode_data[k] = np.array(data_[k])
paths.append(episode_data)
data_ = collections.defaultdict(list)
episode_step += 1
self.trajectories = paths
# calculate state mean and variance and returns_to_go for all traj
states = []
for traj in self.trajectories:
traj_len = traj['observations'].shape[0]
states.append(traj['observations'])
# calculate returns to go and rescale them
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
# used for input normalization
states = np.concatenate(states, axis=0)
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
# normalize states
for traj in self.trajectories:
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
elif 'pkl' in dataset_path:
if 'dqn' in dataset_path:
# load dataset
with open(dataset_path, 'rb') as f:
self.trajectories = pickle.load(f)
if isinstance(self.trajectories[0], list):
# for our collected dataset, e.g. cartpole/lunarlander case
trajectories_tmp = []
original_keys = ['obs', 'next_obs', 'action', 'reward']
keys = ['observations', 'next_observations', 'actions', 'rewards']
trajectories_tmp = [
{
key: np.stack(
[
self.trajectories[eps_index][transition_index][o_key]
for transition_index in range(len(self.trajectories[eps_index]))
],
axis=0
)
for key, o_key in zip(keys, original_keys)
} for eps_index in range(len(self.trajectories))
]
self.trajectories = trajectories_tmp
states = []
for traj in self.trajectories:
# traj_len = traj['observations'].shape[0]
states.append(traj['observations'])
# calculate returns to go and rescale them
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
# used for input normalization
states = np.concatenate(states, axis=0)
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
# normalize states
for traj in self.trajectories:
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
else:
# load dataset
with open(dataset_path, 'rb') as f:
self.trajectories = pickle.load(f)
states = []
for traj in self.trajectories:
states.append(traj['observations'])
# calculate returns to go and rescale them
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
# used for input normalization
states = np.concatenate(states, axis=0)
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
# normalize states
for traj in self.trajectories:
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
else:
# -- load data from memory (make more efficient)
obss = []
actions = []
returns = [0]
done_idxs = []
stepwise_returns = []
transitions_per_buffer = np.zeros(50, dtype=int)
num_trajectories = 0
while len(obss) < cfg.dataset.num_steps:
buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0]
i = transitions_per_buffer[buffer_num]
frb = FixedReplayBuffer(
data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs',
replay_suffix=buffer_num,
observation_shape=(84, 84),
stack_size=4,
update_horizon=1,
gamma=0.99,
observation_dtype=np.uint8,
batch_size=32,
replay_capacity=100000
)
if frb._loaded_buffers:
done = False
curr_num_transitions = len(obss)
trajectories_to_load = cfg.dataset.trajectories_per_buffer
while not done:
states, ac, ret, next_states, next_action, next_reward, terminal, indices = \
frb.sample_transition_batch(batch_size=1, indices=[i])
states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84)
obss.append(states)
actions.append(ac[0])
stepwise_returns.append(ret[0])
if terminal[0]:
done_idxs.append(len(obss))
returns.append(0)
if trajectories_to_load == 0:
done = True
else:
trajectories_to_load -= 1
returns[-1] += ret[0]
i += 1
if i >= 100000:
obss = obss[:curr_num_transitions]
actions = actions[:curr_num_transitions]
stepwise_returns = stepwise_returns[:curr_num_transitions]
returns[-1] = 0
i = transitions_per_buffer[buffer_num]
done = True
num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load)
transitions_per_buffer[buffer_num] = i
actions = np.array(actions)
returns = np.array(returns)
stepwise_returns = np.array(stepwise_returns)
done_idxs = np.array(done_idxs)
# -- create reward-to-go dataset
start_index = 0
rtg = np.zeros_like(stepwise_returns)
for i in done_idxs:
i = int(i)
curr_traj_returns = stepwise_returns[start_index:i]
for j in range(i - 1, start_index - 1, -1): # start from i-1
rtg_j = curr_traj_returns[j - start_index:i - start_index]
rtg[j] = sum(rtg_j)
start_index = i
# -- create timestep dataset
start_index = 0
timesteps = np.zeros(len(actions) + 1, dtype=int)
for i in done_idxs:
i = int(i)
timesteps[start_index:i + 1] = np.arange(i + 1 - start_index)
start_index = i + 1
self.obss = obss
self.actions = actions
self.done_idxs = done_idxs
self.rtgs = rtg
self.timesteps = timesteps
# return obss, actions, returns, done_idxs, rtg, timesteps
def get_max_timestep(self) -> int:
"""
Overview:
Get the max timestep of the dataset.
"""
return max(self.timesteps)
def get_state_stats(self) -> Tuple[np.ndarray, np.ndarray]:
"""
Overview:
Get the state mean and std of the dataset.
"""
return deepcopy(self.state_mean), deepcopy(self.state_std)
def get_d4rl_dataset_stats(self, env_d4rl_name: str) -> Dict[str, list]:
"""
Overview:
Get the d4rl dataset stats.
Arguments:
- env_d4rl_name (:obj:`str`): The d4rl env name.
"""
return self.D4RL_DATASET_STATS[env_d4rl_name]
def __len__(self) -> int:
"""
Overview:
Get the length of the dataset.
"""
if self.env_type != 'atari':
return len(self.trajectories)
else:
return len(self.obss) - self.context_len
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Overview:
Get the item of the dataset.
Arguments:
- idx (:obj:`int`): The index of the dataset.
"""
if self.env_type != 'atari':
traj = self.trajectories[idx]
traj_len = traj['observations'].shape[0]
if traj_len > self.context_len:
# sample random index to slice trajectory
si = np.random.randint(0, traj_len - self.context_len)
states = torch.from_numpy(traj['observations'][si:si + self.context_len])
actions = torch.from_numpy(traj['actions'][si:si + self.context_len])
returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len])
timesteps = torch.arange(start=si, end=si + self.context_len, step=1)
# all ones since no padding
traj_mask = torch.ones(self.context_len, dtype=torch.long)
else:
padding_len = self.context_len - traj_len
# padding with zeros
states = torch.from_numpy(traj['observations'])
states = torch.cat(
[states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0
)
actions = torch.from_numpy(traj['actions'])
actions = torch.cat(
[actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0
)
returns_to_go = torch.from_numpy(traj['returns_to_go'])
returns_to_go = torch.cat(
[
returns_to_go,
torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype)
],
dim=0
)
timesteps = torch.arange(start=0, end=self.context_len, step=1)
traj_mask = torch.cat(
[torch.ones(traj_len, dtype=torch.long),
torch.zeros(padding_len, dtype=torch.long)], dim=0
)
return timesteps, states, actions, returns_to_go, traj_mask
else: # mean cost less than 0.001s
block_size = self.context_len
done_idx = idx + block_size
for i in self.done_idxs:
if i > idx: # first done_idx greater than idx
done_idx = min(int(i), done_idx)
break
idx = done_idx - block_size
states = torch.as_tensor(
np.array(self.obss[idx:done_idx]), dtype=torch.float32
).view(block_size, -1) # (block_size, 4*84*84)
states = states / 255.
actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1)
rtgs = torch.as_tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1)
timesteps = torch.as_tensor(self.timesteps[idx:idx + 1], dtype=torch.int64).unsqueeze(1)
traj_mask = torch.ones(self.context_len, dtype=torch.long)
return timesteps, states, actions, rtgs, traj_mask
@DATASET_REGISTRY.register('d4rl_diffuser')
class D4RLDiffuserDataset(Dataset):
"""
Overview:
D4RL diffuser dataset, which is used for offline RL algorithms.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
"""
def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> None:
"""
Overview:
Initialization method of D4RLDiffuserDataset.
Arguments:
- dataset_path (:obj:`str`): The dataset path.
- context_len (:obj:`int`): The length of the context.
- rtg_scale (:obj:`float`): The scale of the returns to go.
"""
self.context_len = context_len
# load dataset
with open(dataset_path, 'rb') as f:
self.trajectories = pickle.load(f)
if isinstance(self.trajectories[0], list):
# for our collected dataset, e.g. cartpole/lunarlander case
trajectories_tmp = []
original_keys = ['obs', 'next_obs', 'action', 'reward']
keys = ['observations', 'next_observations', 'actions', 'rewards']
for key, o_key in zip(keys, original_keys):
trajectories_tmp = [
{
key: np.stack(
[
self.trajectories[eps_index][transition_index][o_key]
for transition_index in range(len(self.trajectories[eps_index]))
],
axis=0
)
} for eps_index in range(len(self.trajectories))
]
self.trajectories = trajectories_tmp
states = []
for traj in self.trajectories:
traj_len = traj['observations'].shape[0]
states.append(traj['observations'])
# calculate returns to go and rescale them
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
# used for input normalization
states = np.concatenate(states, axis=0)
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
# normalize states
for traj in self.trajectories:
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
class FixedReplayBuffer(object):
"""
Overview:
Object composed of a list of OutofGraphReplayBuffers.
Interfaces:
``__init__``, ``get_transition_elements``, ``sample_transition_batch``
"""
def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
"""
Overview:
Initialize the FixedReplayBuffer class.
Arguments:
- data_dir (:obj:`str`): log Directory from which to load the replay buffer.
- replay_suffix (:obj:`int`): If not None, then only load the replay buffer \
corresponding to the specific suffix in data directory.
- args (:obj:`list`): Arbitrary extra arguments.
- kwargs (:obj:`dict`): Arbitrary keyword arguments.
"""
self._args = args
self._kwargs = kwargs
self._data_dir = data_dir
self._loaded_buffers = False
self.add_count = np.array(0)
self._replay_suffix = replay_suffix
if not self._loaded_buffers:
if replay_suffix is not None:
assert replay_suffix >= 0, 'Please pass a non-negative replay suffix'
self.load_single_buffer(replay_suffix)
else:
pass
# self._load_replay_buffers(num_buffers=50)
def load_single_buffer(self, suffix):
"""
Overview:
Load a single replay buffer.
Arguments:
- suffix (:obj:`int`): The suffix of the replay buffer.
"""
replay_buffer = self._load_buffer(suffix)
if replay_buffer is not None:
self._replay_buffers = [replay_buffer]
self.add_count = replay_buffer.add_count
self._num_replay_buffers = 1
self._loaded_buffers = True
def _load_buffer(self, suffix):
"""
Overview:
Loads a OutOfGraphReplayBuffer replay buffer.
Arguments:
- suffix (:obj:`int`): The suffix of the replay buffer.
"""
try:
from dopamine.replay_memory import circular_replay_buffer
STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX
# pytype: disable=attribute-error
replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(*self._args, **self._kwargs)
replay_buffer.load(self._data_dir, suffix)
# pytype: enable=attribute-error
return replay_buffer
# except tf.errors.NotFoundError:
except:
raise ('can not load')
def get_transition_elements(self):
"""
Overview:
Returns the transition elements.
"""
return self._replay_buffers[0].get_transition_elements()
def sample_transition_batch(self, batch_size=None, indices=None):
"""
Overview:
Returns a batch of transitions (including any extra contents).
Arguments:
- batch_size (:obj:`int`): The batch size.
- indices (:obj:`list`): The indices of the batch.
"""
buffer_index = np.random.randint(self._num_replay_buffers)
return self._replay_buffers[buffer_index].sample_transition_batch(batch_size=batch_size, indices=indices)
class PCDataset(Dataset):
"""
Overview:
Dataset for Procedure Cloning.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
"""
def __init__(self, all_data):
"""
Overview:
Initialization method of PCDataset.
Arguments:
- all_data (:obj:`tuple`): The tuple of all data.
"""
self._data = all_data
def __getitem__(self, item):
"""
Overview:
Get the item of the dataset.
Arguments:
- item (:obj:`int`): The index of the dataset.
"""
return {'obs': self._data[0][item], 'bfs_in': self._data[1][item], 'bfs_out': self._data[2][item]}
def __len__(self):
"""
Overview:
Get the length of the dataset.
"""
return self._data[0].shape[0]
def load_bfs_datasets(train_seeds=1, test_seeds=5):
"""
Overview:
Load BFS datasets.
Arguments:
- train_seeds (:obj:`int`): The number of train seeds.
- test_seeds (:obj:`int`): The number of test seeds.
"""
from dizoo.maze.envs import Maze
def load_env(seed):
ccc = easydict.EasyDict({'size': 16})
e = Maze(ccc)
e.seed(seed)
e.reset()
return e
envs = [load_env(i) for i in range(train_seeds + test_seeds)]
observations_train = []
observations_test = []
bfs_input_maps_train = []
bfs_input_maps_test = []
bfs_output_maps_train = []
bfs_output_maps_test = []
for idx, env in enumerate(envs):
if idx < train_seeds:
observations = observations_train
bfs_input_maps = bfs_input_maps_train
bfs_output_maps = bfs_output_maps_train
else:
observations = observations_test
bfs_input_maps = bfs_input_maps_test
bfs_output_maps = bfs_output_maps_test
start_obs = env.process_states(env._get_obs(), env.get_maze_map())
_, track_back = get_vi_sequence(env, start_obs)
env_observations = torch.stack([track_back[i][0] for i in range(len(track_back))], dim=0)
for i in range(env_observations.shape[0]):
bfs_sequence, _ = get_vi_sequence(env, env_observations[i].numpy().astype(np.int32)) # [L, W, W]
bfs_input_map = env.n_action * np.ones([env.size, env.size], dtype=np.long)
for j in range(bfs_sequence.shape[0]):
bfs_input_maps.append(torch.from_numpy(bfs_input_map))
bfs_output_maps.append(torch.from_numpy(bfs_sequence[j]))
observations.append(env_observations[i])
bfs_input_map = bfs_sequence[j]
train_data = PCDataset(
(
torch.stack(observations_train, dim=0),
torch.stack(bfs_input_maps_train, dim=0),
torch.stack(bfs_output_maps_train, dim=0),
)
)
test_data = PCDataset(
(
torch.stack(observations_test, dim=0),
torch.stack(bfs_input_maps_test, dim=0),
torch.stack(bfs_output_maps_test, dim=0),
)
)
return train_data, test_data
@DATASET_REGISTRY.register('bco')
class BCODataset(Dataset):
"""
Overview:
Dataset for Behavioral Cloning from Observation.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
Properties:
- obs (:obj:`np.ndarray`): The observation array.
- action (:obj:`np.ndarray`): The action array.
"""
def __init__(self, data=None):
"""
Overview:
Initialization method of BCODataset.
Arguments:
- data (:obj:`dict`): The data dict.
"""
if data is None:
raise ValueError('Dataset can not be empty!')
else:
self._data = data
def __len__(self):
"""
Overview:
Get the length of the dataset.
"""
return len(self._data['obs'])
def __getitem__(self, idx):
"""
Overview:
Get the item of the dataset.
Arguments:
- idx (:obj:`int`): The index of the dataset.
"""
return {k: self._data[k][idx] for k in self._data.keys()}
@property
def obs(self):
"""
Overview:
Get the observation array.
"""
return self._data['obs']
@property
def action(self):
"""
Overview:
Get the action array.
"""
return self._data['action']
@DATASET_REGISTRY.register('diffuser_traj')
class SequenceDataset(torch.utils.data.Dataset):
"""
Overview:
Dataset for diffuser.
Interfaces:
``__init__``, ``__len__``, ``__getitem__``
"""
def __init__(self, cfg):
"""
Overview:
Initialization method of SequenceDataset.
Arguments:
- cfg (:obj:`dict`): The config dict.
"""
import gym
env_id = cfg.env.env_id
data_path = cfg.policy.collect.get('data_path', None)
env = gym.make(env_id)
dataset = env.get_dataset()
self.returns_scale = cfg.env.returns_scale
self.horizon = cfg.env.horizon
self.max_path_length = cfg.env.max_path_length
self.discount = cfg.policy.learn.discount_factor
self.discounts = self.discount ** np.arange(self.max_path_length)[:, None]
self.use_padding = cfg.env.use_padding
self.include_returns = cfg.env.include_returns
self.env_id = cfg.env.env_id
itr = self.sequence_dataset(env, dataset)
self.n_episodes = 0
fields = {}
for k in dataset.keys():
if 'metadata' in k:
continue
fields[k] = []
fields['path_lengths'] = []
for i, episode in enumerate(itr):
path_length = len(episode['observations'])
assert path_length <= self.max_path_length
fields['path_lengths'].append(path_length)
for key, val in episode.items():
if key not in fields:
fields[key] = []
if val.ndim < 2:
val = np.expand_dims(val, axis=-1)
shape = (self.max_path_length, val.shape[-1])
arr = np.zeros(shape, dtype=np.float32)
arr[:path_length] = val
fields[key].append(arr)
if episode['terminals'].any() and cfg.env.termination_penalty and 'timeouts' in episode:
assert not episode['timeouts'].any(), 'Penalized a timeout episode for early termination'
fields['rewards'][-1][path_length - 1] += cfg.env.termination_penalty
self.n_episodes += 1
for k in fields.keys():
fields[k] = np.array(fields[k])
self.normalizer = DatasetNormalizer(fields, cfg.policy.normalizer, path_lengths=fields['path_lengths'])
self.indices = self.make_indices(fields['path_lengths'], self.horizon)
self.observation_dim = cfg.env.obs_dim
self.action_dim = cfg.env.action_dim
self.fields = fields
self.normalize()
self.normed = False
if cfg.env.normed:
self.vmin, self.vmax = self._get_bounds()
self.normed = True
# shapes = {key: val.shape for key, val in self.fields.items()}
# print(f'[ datasets/mujoco ] Dataset fields: {shapes}')
def sequence_dataset(self, env, dataset=None):
"""
Overview:
Sequence the dataset.
Arguments:
- env (:obj:`gym.Env`): The gym env.
"""
import collections
N = dataset['rewards'].shape[0]
if 'maze2d' in env.spec.id:
dataset = self.maze2d_set_terminals(env, dataset)
data_ = collections.defaultdict(list)
# The newer version of the dataset adds an explicit
# timeouts field. Keep old method for backwards compatability.
use_timeouts = 'timeouts' in dataset
episode_step = 0
for i in range(N):
done_bool = bool(dataset['terminals'][i])
if use_timeouts:
final_timestep = dataset['timeouts'][i]
else:
final_timestep = (episode_step == env._max_episode_steps - 1)
for k in dataset:
if 'metadata' in k:
continue
data_[k].append(dataset[k][i])
if done_bool or final_timestep:
episode_step = 0
episode_data = {}
for k in data_:
episode_data[k] = np.array(data_[k])
if 'maze2d' in env.spec.id:
episode_data = self.process_maze2d_episode(episode_data)
yield episode_data
data_ = collections.defaultdict(list)
episode_step += 1
def maze2d_set_terminals(self, env, dataset):
"""
Overview:
Set the terminals for maze2d.
Arguments:
- env (:obj:`gym.Env`): The gym env.
- dataset (:obj:`dict`): The dataset dict.
"""
goal = env.get_target()
threshold = 0.5
xy = dataset['observations'][:, :2]
distances = np.linalg.norm(xy - goal, axis=-1)
at_goal = distances < threshold
timeouts = np.zeros_like(dataset['timeouts'])
# timeout at time t iff
# at goal at time t and
# not at goal at time t + 1
timeouts[:-1] = at_goal[:-1] * ~at_goal[1:]
timeout_steps = np.where(timeouts)[0]
path_lengths = timeout_steps[1:] - timeout_steps[:-1]
print(
f'[ utils/preprocessing ] Segmented {env.spec.id} | {len(path_lengths)} paths | '
f'min length: {path_lengths.min()} | max length: {path_lengths.max()}'
)
dataset['timeouts'] = timeouts
return dataset
def process_maze2d_episode(self, episode):
"""
Overview:
Process the maze2d episode, adds in `next_observations` field to episode.
Arguments:
- episode (:obj:`dict`): The episode dict.
"""
assert 'next_observations' not in episode
length = len(episode['observations'])
next_observations = episode['observations'][1:].copy()
for key, val in episode.items():
episode[key] = val[:-1]
episode['next_observations'] = next_observations
return episode
def normalize(self, keys=['observations', 'actions']):
"""
Overview:
Normalize the dataset, normalize fields that will be predicted by the diffusion model
Arguments:
- keys (:obj:`list`): The list of keys.
"""
for key in keys:
array = self.fields[key].reshape(self.n_episodes * self.max_path_length, -1)
normed = self.normalizer.normalize(array, key)
self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1)
def make_indices(self, path_lengths, horizon):
"""
Overview:
Make indices for sampling from dataset. Each index maps to a datapoint.
Arguments:
- path_lengths (:obj:`np.ndarray`): The path length array.
- horizon (:obj:`int`): The horizon.
"""
indices = []
for i, path_length in enumerate(path_lengths):
max_start = min(path_length - 1, self.max_path_length - horizon)
if not self.use_padding:
max_start = min(max_start, path_length - horizon)
for start in range(max_start):
end = start + horizon
indices.append((i, start, end))
indices = np.array(indices)
return indices
def get_conditions(self, observations):
"""
Overview:
Get the conditions on current observation for planning.
Arguments:
- observations (:obj:`np.ndarray`): The observation array.
"""
if 'maze2d' in self.env_id:
return {'condition_id': [0, self.horizon - 1], 'condition_val': [observations[0], observations[-1]]}
else:
return {'condition_id': [0], 'condition_val': [observations[0]]}
def __len__(self):
"""
Overview:
Get the length of the dataset.
"""
return len(self.indices)
def _get_bounds(self):
"""
Overview:
Get the bounds of the dataset.
"""
print('[ datasets/sequence ] Getting value dataset bounds...', end=' ', flush=True)
vmin = np.inf
vmax = -np.inf
for i in range(len(self.indices)):
value = self.__getitem__(i)['returns'].item()
vmin = min(value, vmin)
vmax = max(value, vmax)
print('✓')
return vmin, vmax
def normalize_value(self, value):
"""
Overview:
Normalize the value.
Arguments:
- value (:obj:`np.ndarray`): The value array.
"""
# [0, 1]
normed = (value - self.vmin) / (self.vmax - self.vmin)
# [-1, 1]
normed = normed * 2 - 1
return normed
def __getitem__(self, idx, eps=1e-4):
"""
Overview:
Get the item of the dataset.
Arguments:
- idx (:obj:`int`): The index of the dataset.
- eps (:obj:`float`): The epsilon.
"""
path_ind, start, end = self.indices[idx]
observations = self.fields['normed_observations'][path_ind, start:end]
actions = self.fields['normed_actions'][path_ind, start:end]
done = self.fields['terminals'][path_ind, start:end]
# conditions = self.get_conditions(observations)
trajectories = np.concatenate([actions, observations], axis=-1)
if self.include_returns:
rewards = self.fields['rewards'][path_ind, start:]
discounts = self.discounts[:len(rewards)]
returns = (discounts * rewards).sum()
if self.normed:
returns = self.normalize_value(returns)
returns = np.array([returns / self.returns_scale], dtype=np.float32)
batch = {
'trajectories': trajectories,
'returns': returns,
'done': done,
'action': actions,
}
else:
batch = {
'trajectories': trajectories,
'done': done,
'action': actions,
}
batch.update(self.get_conditions(observations))
return batch
def hdf5_save(exp_data, expert_data_path):
"""
Overview:
Save the data to hdf5.
"""
try:
import h5py
except ImportError:
import sys
logging.warning("not found h5py package, please install it trough 'pip install h5py' ")
sys.exit(1)
dataset = dataset = h5py.File('%s_demos.hdf5' % expert_data_path.replace('.pkl', ''), 'w')
dataset.create_dataset('obs', data=np.array([d['obs'].numpy() for d in exp_data]), compression='gzip')
dataset.create_dataset('action', data=np.array([d['action'].numpy() for d in exp_data]), compression='gzip')
dataset.create_dataset('reward', data=np.array([d['reward'].numpy() for d in exp_data]), compression='gzip')
dataset.create_dataset('done', data=np.array([d['done'] for d in exp_data]), compression='gzip')
dataset.create_dataset('next_obs', data=np.array([d['next_obs'].numpy() for d in exp_data]), compression='gzip')
def naive_save(exp_data, expert_data_path):
"""
Overview:
Save the data to pickle.
"""
with open(expert_data_path, 'wb') as f:
pickle.dump(exp_data, f)
def offline_data_save_type(exp_data, expert_data_path, data_type='naive'):
"""
Overview:
Save the offline data.
"""
globals()[data_type + '_save'](exp_data, expert_data_path)
def create_dataset(cfg, **kwargs) -> Dataset:
"""
Overview:
Create dataset.
"""
cfg = EasyDict(cfg)
import_module(cfg.get('import_names', []))
return DATASET_REGISTRY.build(cfg.policy.collect.data_type, cfg=cfg, **kwargs)