|
import copy
import os
from collections import namedtuple
from typing import List

import pytest
import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, EpsGreedySampleWrapper, TargetNetworkWrapper
from ding.policy.cql import CQLPolicy, DiscreteCQLPolicy
from ding.utils import deep_merge_dicts
from ding.utils.data import offline_data_save_type
|
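# Small toy dimensions plus four configs: cfg1/cfg2 exercise continuous-action CQL
# (with and without automatic alpha tuning / log-space alpha), cfg3/cfg4 exercise
# discrete-action CQL (with and without automatic alpha tuning).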
obs_space = 5
action_space = 3

cfg1 = EasyDict(CQLPolicy.default_config())
cfg1.model.obs_shape = obs_space
cfg1.model.action_shape = action_space

cfg2 = copy.deepcopy(cfg1)
cfg2.learn.auto_alpha = False
cfg2.learn.log_space = False

cfg3 = EasyDict(DiscreteCQLPolicy.default_config())
cfg3.model = {}
cfg3.model.obs_shape = obs_space
cfg3.model.action_shape = action_space

cfg4 = copy.deepcopy(cfg3)
cfg4.learn.auto_alpha = False
|
|
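# Dict of zero observations keyed by index (stands in for the per-env obs dict that
# `_forward_collect` consumes below); purely fake data for shape checks.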
def get_batch(size=8):
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data
|
|
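# Fake one-step transitions (zero obs/action, constant reward) used as the offline
# batch for the continuous policy's `_forward_learn`.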
def get_transition(size=20):
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.zeros(action_space)
        sample['done'] = False
        sample['next_obs'] = torch.zeros(obs_space)
        sample['reward'] = torch.Tensor([1.])
        data.append(sample)
    return data
|
|
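# A single collated batch of zero obs/actions, used by the Q-value / policy-action
# helpers and as input to `_get_train_sample` in the discrete test.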
def get_transition_batch(bs=1):
    sample = {}
    sample['obs'] = torch.zeros(bs, obs_space)
    sample['action'] = torch.zeros(bs, action_space)
    return sample
|
|
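# Continuous-action CQL: check the target-model wrapper, the shapes produced by the
# Q-value and policy-action helpers, and one learner forward pass on fake transitions.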
@pytest.mark.parametrize('cfg', [cfg1, cfg2])
@pytest.mark.unittest
def test_cql_continuous(cfg):
    policy = CQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
    assert type(policy._target_model) == TargetNetworkWrapper
    q_value = policy._get_q_value(get_transition_batch(cfg.learn.num_actions))
    assert q_value[0].shape[-1] == 1 and q_value[0].shape[-2] == cfg.learn.num_actions
    act, log_prob = policy._get_policy_actions(get_transition_batch(cfg.learn.num_actions))
    assert list(act.shape) == [cfg.learn.num_actions * 10, action_space]
    sample = get_transition(size=20)
    out = policy._forward_learn(sample)
|
|
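# Same as get_transition, but with integer class-index actions for the discrete policy.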
def get_transition_discrete(size=20):
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(i % action_space)
        sample['done'] = False
        sample['next_obs'] = torch.zeros(obs_space)
        sample['reward'] = torch.Tensor([1.])
        data.append(sample)
    return data
|
|
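# Discrete-action CQL: check the model wrappers, train-sample creation, the state-dict
# save/load round trip, and one learn and one collect forward pass.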
@pytest.mark.parametrize('cfg', [cfg3, cfg4])
@pytest.mark.unittest
def test_cql_discrete(cfg):
    policy = DiscreteCQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == TargetNetworkWrapper
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    sample = get_transition_batch(bs=20)
    samples = policy._get_train_sample(sample)
    assert len(samples['obs']) == 20
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
    sample = get_transition_discrete(size=1)
    out = policy._forward_learn(sample)
    out = policy._forward_collect(get_batch(size=8), eps=0.1)