File size: 1,339 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import pytest
import numpy as np
import random
import torch
from ding.data.level_replay.level_sampler import LevelSampler


@pytest.mark.unittest
def test_level_sampler():
    """Smoke-test ``LevelSampler``: feed one synthetic rollout, then draw a seed.

    Builds a sampler over ``num_seeds`` training seeds, updates it with randomly
    generated rollout tensors via ``update_with_rollouts`` and checks that
    ``sample()`` returns a plain ``int``.
    """
    num_seeds = 500
    obs_shape = [3, 64, 64]
    action_shape = 15
    collector_env_num = 16
    level_replay_dict = dict(
        strategy='min_margin',
        score_transform='rank',
        temperature=0.1,
    )
    N = 10
    collector_sample_length = 160
    # Number of consecutive rollout entries contributed by each collector env.
    steps_per_env = collector_sample_length // collector_env_num

    train_seeds = list(range(num_seeds))
    level_sampler = LevelSampler(train_seeds, obs_shape, action_shape, collector_env_num, level_replay_dict)

    value = torch.randn(collector_sample_length)
    reward = torch.randn(collector_sample_length)
    adv = torch.randn(collector_sample_length)
    done = torch.randn(collector_sample_length)
    logit = torch.randn(collector_sample_length, N)
    # randrange excludes its upper bound, so every drawn seed is a member of
    # train_seeds; random.randint(0, num_seeds) could also yield num_seeds itself,
    # which is not one of the seeds the sampler was constructed with.
    seeds = [random.randrange(num_seeds) for _ in range(collector_env_num)]
    # Repeat each env's seed once per step so 'seed' has collector_sample_length entries.
    all_seeds = torch.Tensor([seed for seed in seeds for _ in range(steps_per_env)])

    train_data = {'value': value, 'reward': reward, 'adv': adv, 'done': done, 'logit': logit, 'seed': all_seeds}
    level_sampler.update_with_rollouts(train_data, collector_env_num)
    sample_seed = level_sampler.sample()
    assert isinstance(sample_seed, int)