"""Unit tests for the continuous ``CQLPolicy`` and ``DiscreteCQLPolicy``."""
import copy

import pytest
import torch
from easydict import EasyDict

from ding.policy.cql import CQLPolicy, DiscreteCQLPolicy
from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, EpsGreedySampleWrapper, TargetNetworkWrapper

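# Toy observation/action dimensions shared by every config below.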
obs_space = 5
action_space = 3

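# cfg1: default continuous CQL config with the toy model shapes.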
cfg1 = EasyDict(CQLPolicy.default_config())
cfg1.model.obs_shape = obs_space
cfg1.model.action_shape = action_space

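# cfg2: same setup, but with automatic entropy (alpha) tuning disabled and alpha not kept in log space.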
cfg2 = copy.deepcopy(cfg1)
cfg2.learn.auto_alpha = False
cfg2.learn.log_space = False

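# cfg3: discrete CQL config; the model section is rebuilt with the same toy shapes.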
cfg3 = EasyDict(DiscreteCQLPolicy.default_config())
cfg3.model = {}
cfg3.model.obs_shape = obs_space
cfg3.model.action_shape = action_space

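# cfg4: discrete variant with ``auto_alpha`` turned off.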
cfg4 = copy.deepcopy(cfg3)
cfg4.learn.auto_alpha = False


def get_batch(size=8):
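    """Fake collector input: a dict keyed by env id with zero observations, as consumed by ``_forward_collect``."""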
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data


def get_transition(size=20):
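    """Fake offline dataset: a list of single-step transitions with continuous (float) actions."""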
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.zeros(action_space)
        sample['done'] = False
        sample['next_obs'] = torch.zeros(obs_space)
        sample['reward'] = torch.Tensor([1.])
        data.append(sample)
    return data


def get_transition_batch(bs=1):
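    """Batched zero obs/action tensors, used to probe value/policy heads and as a train-sample stub."""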
    sample = {}
    sample['obs'] = torch.zeros(bs, obs_space)
    sample['action'] = torch.zeros(bs, action_space)
    return sample


@pytest.mark.parametrize('cfg', [cfg1, cfg2])
@pytest.mark.unittest
def test_cql_continuous(cfg):
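    """Smoke test for the continuous ``CQLPolicy``: wrapper type, Q-value and policy-action shapes, one learn step."""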
    policy = CQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
    assert type(policy._target_model) == TargetNetworkWrapper
    q_value = policy._get_q_value(get_transition_batch(cfg.learn.num_actions))
    assert q_value[0].shape[-1] == 1 and q_value[0].shape[-2] == cfg.learn.num_actions
    act, log_prob = policy._get_policy_actions(get_transition_batch(cfg.learn.num_actions))
    assert list(act.shape) == [cfg.learn.num_actions * 10, action_space]
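    # A single learn step over a small batch of fake transitions should run end to end.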
    sample = get_transition(size=20)
    out = policy._forward_learn(sample)


def get_transition_discrete(size=20):
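    """Fake offline dataset with integer (discrete) actions for the discrete CQL learn step."""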
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(i % action_space)
        sample['done'] = False
        sample['next_obs'] = torch.zeros(obs_space)
        sample['reward'] = torch.Tensor([1.])
        data.append(sample)
    return data


@pytest.mark.parametrize('cfg', [cfg3, cfg4])
@pytest.mark.unittest
def test_cql_discrete(cfg):
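    """Smoke test for ``DiscreteCQLPolicy``: model wrappers, train-sample generation,
    state-dict round trip, and one learn/collect forward pass."""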
    policy = DiscreteCQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == TargetNetworkWrapper
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    sample = get_transition_batch(bs=20)
    samples = policy._get_train_sample(sample)
    assert len(samples['obs']) == 20
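    # Saving and reloading the learn state should round-trip without error.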
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
    sample = get_transition_discrete(size=1)
    out = policy._forward_learn(sample)
    out = policy._forward_collect(get_batch(size=8), eps=0.1)
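

# Optional convenience entry point: running this module directly with ``python`` is
# equivalent to invoking the pytest CLI on this file with ``-s -v``.
if __name__ == '__main__':
    pytest.main(['-sv', __file__])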