# Source: gomoku / DI-engine / ding/data/level_replay/tests/test_level_sampler.py
# Last commit: 079c32c "init space" by zjowowen (1.34 kB)
import pytest
import numpy as np
import random
import torch
from ding.data.level_replay.level_sampler import LevelSampler
@pytest.mark.unittest
def test_level_sampler():
    """Smoke-test ``LevelSampler``: feed one synthetic rollout batch, then sample a seed.

    A batch of ``collector_sample_length`` random transitions is split evenly
    across ``collector_env_num`` collector envs, and every step of an env is
    tagged with that env's training seed. After ``update_with_rollouts`` the
    sampler is expected to return a plain ``int`` that belongs to ``train_seeds``.
    """
    num_seeds = 500
    obs_shape = [3, 64, 64]
    action_shape = 15
    collector_env_num = 16
    level_replay_dict = dict(
        strategy='min_margin',
        score_transform='rank',
        temperature=0.1,
    )
    # NOTE(review): logit width N (10) differs from action_shape (15) — confirm
    # whether LevelSampler is meant to accept mismatched widths.
    N = 10
    collector_sample_length = 160
    # Each collector env contributes an equal, contiguous slice of the batch.
    steps_per_env = collector_sample_length // collector_env_num

    train_seeds = list(range(num_seeds))
    level_sampler = LevelSampler(train_seeds, obs_shape, action_shape, collector_env_num, level_replay_dict)

    value = torch.randn(collector_sample_length)
    reward = torch.randn(collector_sample_length)
    adv = torch.randn(collector_sample_length)
    # NOTE(review): `done` is Gaussian noise rather than 0/1 episode flags —
    # confirm this is the intended input for the done-handling path.
    done = torch.randn(collector_sample_length)
    logit = torch.randn(collector_sample_length, N)
    # Valid seeds are 0..num_seeds-1. random.randint(0, num_seeds) is inclusive
    # of num_seeds and could tag an env with a seed absent from train_seeds.
    seeds = [random.randrange(num_seeds) for _ in range(collector_env_num)]
    # Repeat each env's seed once per step it contributes, in env-major order.
    all_seeds = torch.Tensor([seed for seed in seeds for _ in range(steps_per_env)])

    train_data = {'value': value, 'reward': reward, 'adv': adv, 'done': done, 'logit': logit, 'seed': all_seeds}
    level_sampler.update_with_rollouts(train_data, collector_env_num)

    sample_seed = level_sampler.sample()
    assert isinstance(sample_seed, int)
    assert sample_seed in train_seeds