|
import pytest |
|
import numpy as np |
|
import random |
|
import torch |
|
from ding.data.level_replay.level_sampler import LevelSampler |
|
|
|
|
|
@pytest.mark.unittest
def test_level_sampler():
    """Smoke-test ``LevelSampler``: feed one synthetic rollout batch, then sample.

    Builds a ``LevelSampler`` over ``num_seeds`` training seeds, pushes a single
    batch of random rollout data through ``update_with_rollouts`` and checks
    that ``sample()`` returns a plain Python ``int``.
    """
    num_seeds = 500
    obs_shape = [3, 64, 64]
    action_shape = 15
    collector_env_num = 16
    level_replay_dict = dict(
        strategy='min_margin',
        score_transform='rank',
        temperature=0.1,
    )
    N = 10  # width of the synthetic per-step logit vector
    collector_sample_length = 160
    # Each collector env contributes an equal number of steps to the batch.
    steps_per_env = collector_sample_length // collector_env_num

    train_seeds = list(range(num_seeds))
    level_sampler = LevelSampler(train_seeds, obs_shape, action_shape, collector_env_num, level_replay_dict)

    value = torch.randn(collector_sample_length)
    reward = torch.randn(collector_sample_length)
    adv = torch.randn(collector_sample_length)
    # NOTE(review): `done` is Gaussian noise rather than 0/1 flags, so almost
    # every entry is non-zero — confirm this is the intended test signal.
    done = torch.randn(collector_sample_length)
    logit = torch.randn(collector_sample_length, N)
    # Pick one seed per env from the registered training seeds. The former
    # `random.randint(0, num_seeds)` is inclusive of `num_seeds`, a value that
    # is not in `train_seeds`, so the rollout could reference an unknown level
    # and make the test flaky.
    seeds = [random.choice(train_seeds) for _ in range(collector_env_num)]
    # Repeat each env's seed once per step it contributes, env by env
    # (seeds[0] x steps_per_env, then seeds[1] x steps_per_env, ...).
    all_seeds = torch.Tensor([seed for seed in seeds for _ in range(steps_per_env)])

    train_data = {'value': value, 'reward': reward, 'adv': adv, 'done': done, 'logit': logit, 'seed': all_seeds}
    level_sampler.update_with_rollouts(train_data, collector_env_num)
    sample_seed = level_sampler.sample()
    assert isinstance(sample_seed, int)
|
|