zjowowen's picture
init space
079c32c
raw
history blame
1.31 kB
import pytest
import torch
from ding.rl_utils.upgo import upgo_loss, upgo_returns, tb_cross_entropy
@pytest.mark.unittest
def test_upgo():
    """Smoke-test ``tb_cross_entropy``, ``upgo_returns`` and ``upgo_loss``.

    Verifies the output shapes produced for 4-D and 3-D logits, that a 5-D
    logit is rejected with ``AssertionError``, and that back-propagating the
    UPGO loss populates the gradient of the logit.
    """
    # Sizes mirror the original T, B, N, N2 (presumably time steps, batch
    # size and two trailing logit dimensions -- TODO confirm against upgo.py).
    steps, batch, n1, n2 = 4, 8, 5, 7
    expected_shape = (steps, batch)

    def make_inputs(*shape):
        """Return a softmax-ed leaf tensor of ``shape`` and its argmax indices."""
        probs = torch.randn(*shape).softmax(-1).requires_grad_(True)
        return probs, probs.argmax(-1).detach()

    # tb_cross_entropy, case 1: a 4-D logit yields a (steps, batch) result.
    logit, action = make_inputs(steps, batch, n1, n2)
    assert tb_cross_entropy(logit, action).shape == expected_shape

    # tb_cross_entropy, case 2: a 5-D logit must trip an assertion.
    logit, action = make_inputs(steps, batch, n1, n2, 2)
    with pytest.raises(AssertionError):
        tb_cross_entropy(logit, action)

    # tb_cross_entropy, case 3: a 3-D logit also yields (steps, batch).
    # This logit/action pair is reused by the upgo_loss check below.
    logit, action = make_inputs(steps, batch, n1)
    assert tb_cross_entropy(logit, action).shape == expected_shape

    # upgo_returns: bootstrap values carry one extra leading entry.
    rewards = torch.randn(steps, batch)
    bootstrap_values = torch.randn(steps + 1, batch).requires_grad_(True)
    assert upgo_returns(rewards, bootstrap_values).shape == expected_shape

    # upgo_loss: both leaves require grad but hold none before backward().
    rhos = torch.randn(steps, batch)
    loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)
    leaves = (logit, bootstrap_values)
    assert all(leaf.requires_grad for leaf in leaves)
    assert all(leaf.grad is None for leaf in leaves)

    loss.backward()
    # Only the logit's gradient is checked after backward(); nothing is
    # asserted about bootstrap_values.grad at this point.
    assert isinstance(logit.grad, torch.Tensor)