zjowowen's picture
init space
079c32c
raw
history blame contribute delete
973 Bytes
import pytest
import torch
from ding.rl_utils import gae_data, gae
@pytest.mark.unittest
def test_gae():
# batch trajectory case
T, B = 32, 4
value = torch.randn(T, B)
next_value = torch.randn(T, B)
reward = torch.randn(T, B)
done = torch.zeros((T, B))
data = gae_data(value, next_value, reward, done, None)
adv = gae(data)
assert adv.shape == (T, B)
# single trajectory case/concat trajectory case
T = 24
value = torch.randn(T)
next_value = torch.randn(T)
reward = torch.randn(T)
done = torch.zeros((T))
data = gae_data(value, next_value, reward, done, None)
adv = gae(data)
assert adv.shape == (T, )
def test_gae_multi_agent():
T, B, A = 32, 4, 8
value = torch.randn(T, B, A)
next_value = torch.randn(T, B, A)
reward = torch.randn(T, B)
done = torch.zeros(T, B)
data = gae_data(value, next_value, reward, done, None)
adv = gae(data)
assert adv.shape == (T, B, A)