File size: 611 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import pytest
import torch
from ding.rl_utils import compute_q_retraces
@pytest.mark.unittest
def test_compute_q_retraces():
    """Smoke-test ``compute_q_retraces`` on randomly generated inputs.

    Random tensors are built, ``compute_q_retraces`` is called under
    ``torch.no_grad()`` with ``gamma=0.99``, and the result is checked to
    have shape ``(T + 1, B, 1)``. Only the output shape is verified; the
    numerical values are not.
    """
    # NOTE(review): presumably T = trajectory length, B = batch size and
    # N = number of discrete actions -- confirm against compute_q_retraces.
    horizon = 64
    batch_size = 32
    action_dim = 6

    # The first two inputs carry one more entry along dim 0 than the rest.
    q_values = torch.randn(horizon + 1, batch_size, action_dim)
    v_pred = torch.randn(horizon + 1, batch_size, 1)
    rewards = torch.randn(horizon, batch_size)

    # torch.rand samples from [0, 1), so this lands in [0.8, 1.2).
    ratio = 0.8 + 0.4 * torch.rand(horizon, batch_size, action_dim)
    assert ratio.min() >= 0.8
    assert ratio.max() <= 1.2

    weights = torch.rand(horizon, batch_size)
    actions = torch.randint(low=0, high=action_dim, size=(horizon, batch_size))

    # No gradients are needed for a pure shape check.
    with torch.no_grad():
        q_retraces = compute_q_retraces(
            q_values, v_pred, rewards, actions, weights, ratio, gamma=0.99
        )

    expected_shape = (horizon + 1, batch_size, 1)
    assert q_retraces.shape == expected_shape
|