gomoku / DI-engine /ding /policy /mbpolicy /tests /test_mbpolicy_utils.py
zjowowen's picture
init space
079c32c
raw
history blame
475 Bytes
import pytest
import torch
from ding.policy.mbpolicy.utils import q_evaluation
@pytest.mark.unittest
def test_q_evaluation():
    """Check the output shape of ``q_evaluation`` on a random trajectory batch.

    Random observations of shape (T, B, O) and actions of shape (T, B, A) are
    passed to ``q_evaluation`` together with a stub critic. The test only
    asserts that the result has shape (T, B); the values are not checked.
    """
    horizon, batch_size, obs_dim, action_dim = 10, 20, 100, 30
    obs_seq = torch.randn(horizon, batch_size, obs_dim)
    action_seq = torch.randn(horizon, batch_size, action_dim)

    def fake_q_fn(obss, actions):
        # Stub critic: ignores ``actions`` and reduces the last dim of ``obss``.
        # Expected layout when called by q_evaluation:
        #   obss:    flatten_B * O
        #   actions: flatten_B * A
        #   return:  flatten_B
        return obss.sum(dim=-1)

    q_value = q_evaluation(obs_seq, action_seq, fake_q_fn)
    expected_shape = (horizon, batch_size)
    assert q_value.shape == expected_shape