File size: 2,395 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import pytest
import torch
from easydict import EasyDict
import os
from ding.utils.data import offline_data_save_type, create_dataset, NaiveRLDataset, D4RLDataset, HDF5Dataset

# Configs for the dataset tests. cfg3 (D4RL) is kept for reference but is
# deliberately left out of ``cfgs``; its test is commented out further below.
cfg1 = dict(
    policy=dict(
        collect=dict(
            data_type='naive',
            data_path='./expert.pkl',
        ),
    ),
)

cfg2 = dict(
    env=dict(
        norm_obs=dict(
            use_norm=True,
            offline_stats=dict(use_offline_stats=True),
        ),
    ),
    policy=dict(
        collect=dict(
            data_type='hdf5',
            data_path='./expert_demos.hdf5',
        ),
    ),
)

cfg3 = dict(
    env=dict(env_id='hopper-expert-v0'),
    policy=dict(collect=dict(data_type='d4rl', ), ),
)

cfgs = [cfg1, cfg2]  # cfg3
unittest_args = ['naive', 'hdf5']

# A single fake transition with all-zero observations/actions and a unit reward.
# Key order matters to anything that iterates over the transition's keys.
transition = dict(
    obs=torch.zeros((3, 1)),
    next_obs=torch.zeros((3, 1)),
    action=torch.zeros((1, 1)),
    reward=torch.tensor((1, )),
    done=False,
    collect_iter=0,
)

# 32 entries that all reference the *same* transition dict (no copies are made).
fake_data = [transition] * 32
expert_data_path = './expert.pkl'


@pytest.mark.parametrize('data_type', unittest_args)
@pytest.mark.unittest
def test_offline_data_save_type(data_type):
    """Save ``fake_data`` via ``offline_data_save_type`` once per format in ``unittest_args``.

    The dataset tests below read back what is written here (``NaiveRLDataset``
    loads ``expert_data_path``), so they rely on this test running first.
    NOTE(review): presumably the 'hdf5' run produces './expert_demos.hdf5'
    used by ``cfg2`` -- confirm against ``offline_data_save_type``.
    """
    save_kwargs = dict(
        exp_data=fake_data,
        expert_data_path=expert_data_path,
        data_type=data_type,
    )
    offline_data_save_type(**save_kwargs)


@pytest.mark.parametrize('cfg', cfgs)
@pytest.mark.unittest
def test_dataset(cfg):
    """Smoke-test the ``create_dataset`` factory for every config listed in ``cfgs``."""
    create_dataset(EasyDict(cfg))


@pytest.mark.parametrize('cfg', [cfg1])
@pytest.mark.unittest
def test_NaiveRLDataset(cfg):
    """Load the data saved at ``expert_data_path`` through both constructor forms.

    ``NaiveRLDataset`` is built once from the EasyDict config and once from the
    raw path string; each instance must be non-empty and indexable.
    """
    from_cfg = NaiveRLDataset(EasyDict(cfg))
    from_path = NaiveRLDataset(expert_data_path)
    for dataset in (from_cfg, from_path):
        # len() can only return an int (or raise), so a type check on it can
        # never fail; assert the dataset actually holds data instead.
        assert len(dataset) > 0
        assert dataset[0] is not None


# @pytest.mark.parametrize('cfg', [cfg3])
# @pytest.mark.unittest
# def test_D4RLDataset(cfg):
#     cfg = EasyDict(cfg)
#     dataset = D4RLDataset(cfg)


@pytest.mark.parametrize('cfg', [cfg2])
@pytest.mark.unittest
def test_HDF5Dataset(cfg):
    """Load the HDF5 demos with observation normalisation enabled (``cfg2``).

    Checks that normalisation statistics are exposed, that the first observation
    feature averages to zero (the fake observations are all zeros), and that the
    dataset is non-empty and indexable.
    """
    dataset = HDF5Dataset(EasyDict(cfg))
    # Separate asserts so a failure pinpoints which statistic is missing.
    assert dataset.mean is not None
    assert dataset.std[0] is not None
    assert dataset._data['obs'].mean(0)[0] == 0
    # len() can only return an int (or raise); check for content instead.
    assert len(dataset) > 0
    assert dataset[0] is not None


@pytest.fixture(scope="session", autouse=True)
def cleanup(request):
    """Session-wide fixture that deletes the files written by these tests at session end."""

    def remove_test_files():
        for path in ('./expert.pkl', './expert_demos.hdf5'):
            # EAFP: a file that was never created (e.g. only a subset of tests
            # ran) is not an error, and this avoids the exists()/remove() race.
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    request.addfinalizer(remove_test_files)