File size: 973 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import pytest
import torch
from ding.rl_utils import gae_data, gae


@pytest.mark.unittest
def test_gae():
    """Check that ``gae`` returns an advantage shaped like its inputs.

    Two input layouts are exercised in turn: a batched trajectory of shape
    ``(T, B) = (32, 4)`` and a single (or concatenated) trajectory of shape
    ``(T,) = (24,)``. For each layout, ``value``, ``next_value`` and
    ``reward`` are random tensors, ``done`` is all zeros, and the fifth
    ``gae_data`` field is left as ``None``. Only the output shape is checked.
    """
    layouts = (
        (32, 4),  # batch trajectory case
        (24, ),  # single trajectory case/concat trajectory case
    )
    for shape in layouts:
        # Draw the tensors in the order value -> next_value -> reward.
        value, next_value, reward = [torch.randn(shape) for _ in range(3)]
        done = torch.zeros(shape)
        batch = gae_data(value, next_value, reward, done, None)
        advantage = gae(batch)
        assert advantage.shape == shape


@pytest.mark.unittest
def test_gae_multi_agent():
    """Check ``gae`` on inputs whose value tensors have an extra trailing axis.

    ``value`` and ``next_value`` have shape ``(T, B, A) = (32, 4, 8)`` while
    ``reward`` and ``done`` have shape ``(T, B)``. The test expects the
    returned advantage to take the shape of ``value``, i.e. ``(T, B, A)``.
    Only the output shape is checked, not the numerical values.

    The ``unittest`` marker matches the one on ``test_gae`` so that a
    marker-filtered run (``pytest -m unittest``) also selects this test.
    """
    T, B, A = 32, 4, 8
    # The value estimates have one more trailing axis (size A) than reward/done.
    value = torch.randn(T, B, A)
    next_value = torch.randn(T, B, A)
    reward = torch.randn(T, B)
    done = torch.zeros(T, B)
    # The fifth gae_data field is left unset.
    # NOTE(review): presumably an optional trajectory flag -- confirm against gae_data.
    data = gae_data(value, next_value, reward, done, None)
    adv = gae(data)
    assert adv.shape == (T, B, A)