import pytest
import torch
from ding.rl_utils.upgo import upgo_loss, upgo_returns, tb_cross_entropy


@pytest.mark.unittest
def test_upgo():
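    # T: unroll (time) length, B: batch size, N and N2: action-space dims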
    T, B, N, N2 = 4, 8, 5, 7

    # tb_cross_entropy: three cases, covering a 4-D logit, an unsupported 5-D logit, and a 3-D logit
    logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    ce = tb_cross_entropy(logit, action)
    assert ce.shape == (T, B)

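    # a logit with more than two trailing action dims is unsupported and should raise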
    logit = torch.randn(T, B, N, N2, 2).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    with pytest.raises(AssertionError):
        ce = tb_cross_entropy(logit, action)

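    # plain discrete-action case: logit of shape (T, B, N)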
    logit = torch.randn(T, B, N).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    ce = tb_cross_entropy(logit, action)
    assert ce.shape == (T, B)

    # upgo_returns: bootstrap_values carries one extra timestep, shape (T + 1, B)
    rewards = torch.randn(T, B)
    bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
    returns = upgo_returns(rewards, bootstrap_values)
    assert returns.shape == (T, B)

    # upgo_loss: check that gradients flow back to the logit via backward()
    rhos = torch.randn(T, B)
    loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)
    assert logit.requires_grad
    assert bootstrap_values.requires_grad
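    # no gradients should have accumulated before backward() is called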
    for t in [logit, bootstrap_values]:
        assert t.grad is None
    loss.backward()
    for t in [logit]:
        assert isinstance(t.grad, torch.Tensor)
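

# For reference, a minimal sketch of the UPGO return recursion as described in
# the AlphaStar paper: the return keeps unrolling while the behavior is
# "upgoing" (r_{t+1} + V_{t+2} >= V_{t+1}, i.e. the realized one-step target
# beats the value estimate) and bootstraps with V_{t+1} otherwise. This is an
# illustrative sketch only; it is not asserted against ding's upgo_returns,
# whose exact bootstrapping convention may differ.
def naive_upgo_returns(rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # rewards: (T, B); values: (T + 1, B) bootstrap value estimates
    T = rewards.shape[0]
    # the last step always bootstraps with the final value estimate
    out = [rewards[-1] + values[-1]]
    for t in reversed(range(T - 1)):
        upgoing = (rewards[t + 1] + values[t + 2]) >= values[t + 1]
        out.append(rewards[t] + torch.where(upgoing, out[-1], values[t + 1]))
    # out was built backwards in time, so reverse before stacking to (T, B)
    return torch.stack(out[::-1])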