File size: 475 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import pytest
import torch
from ding.policy.mbpolicy.utils import q_evaluation
@pytest.mark.unittest
def test_q_evaluation():
    """Check that ``q_evaluation`` yields one value per (time step, batch) pair.

    Random observations of shape ``(T, B, O)`` and actions of shape
    ``(T, B, A)`` are passed through ``q_evaluation`` together with a stub
    critic; the result is expected to have shape ``(T, B)``.
    """
    horizon, batch_size, obs_dim, action_dim = 10, 20, 100, 30
    obs_seq = torch.randn(horizon, batch_size, obs_dim)
    action_seq = torch.randn(horizon, batch_size, action_dim)

    def fake_q_fn(obss, actions):
        """Stub critic that reduces each observation row to a single scalar."""
        # Expected inputs (per the original notes; not asserted here):
        #   obss:    flatten_B * O
        #   actions: flatten_B * A
        # Expected return: flatten_B
        return obss.sum(dim=-1)

    q_value = q_evaluation(obs_seq, action_seq, fake_q_fn)
    expected_shape = (horizon, batch_size)
    assert q_value.shape == expected_shape
|