File size: 973 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import pytest
import torch
from ding.rl_utils import gae_data, gae
@pytest.mark.unittest
def test_gae():
    """Check that ``gae`` returns an advantage tensor shaped like its inputs.

    Two input layouts are exercised one after the other: a batched
    trajectory of shape ``(T, B)`` and a single (or concatenated)
    trajectory of shape ``(T, )``. Only the output shape is asserted.
    """
    input_shapes = (
        (32, 4),  # batch trajectory case: T=32, B=4
        (24, ),  # single trajectory case/concat trajectory case: T=24
    )
    for shape in input_shapes:
        # Random tensors are drawn in the order value, next_value, reward.
        state_value = torch.randn(shape)
        following_value = torch.randn(shape)
        step_reward = torch.randn(shape)
        # All-zero done flags; the fifth gae_data field is left as None.
        done_flags = torch.zeros(shape)
        batch = gae_data(state_value, following_value, step_reward, done_flags, None)
        advantage = gae(batch)
        assert advantage.shape == shape
@pytest.mark.unittest
def test_gae_multi_agent():
    """Check the output shape of ``gae`` when values carry an extra axis.

    ``value`` and ``next_value`` are built with shape ``(T, B, A)`` while
    ``reward`` and ``done`` are built with shape ``(T, B)``; the test asserts
    that the returned advantage has shape ``(T, B, A)``.

    The ``unittest`` marker is applied so this test is selected by the same
    marker-filtered runs as ``test_gae``.
    """
    T, B, A = 32, 4, 8
    # Values have a trailing axis of size A (presumably one entry per agent,
    # going by the test name -- confirm against the gae implementation).
    value = torch.randn(T, B, A)
    next_value = torch.randn(T, B, A)
    # Reward and done deliberately lack the trailing axis.
    reward = torch.randn(T, B)
    done = torch.zeros(T, B)
    # The fifth gae_data field is left as None, as in test_gae.
    data = gae_data(value, next_value, reward, done, None)
    adv = gae(data)
    assert adv.shape == (T, B, A)
|