import pytest | |
import torch | |
from ding.policy.mbpolicy.utils import q_evaluation | |
def test_q_evaluation():
    """Check that ``q_evaluation`` turns (T, B, ...) inputs into a (T, B) result.

    A stub critic is passed in. It ignores the actions and sums each
    observation vector, so this test only checks the shape of the tensor
    that ``q_evaluation`` returns, not its values.
    """
    time_steps, batch_size, obs_dim, act_dim = 10, 20, 100, 30

    def summing_critic(obss, actions):
        # Per the original test's notes, this receives inputs whose leading
        # dims are flattened: obss is (flatten_B, obs_dim) and actions is
        # (flatten_B, act_dim). It should return a (flatten_B,) tensor.
        # The parameter names are kept in case the caller passes them by keyword.
        return obss.sum(dim=-1)

    observations = torch.randn(time_steps, batch_size, obs_dim)
    action_seq = torch.randn(time_steps, batch_size, act_dim)

    result = q_evaluation(observations, action_seq, summing_critic)

    expected_shape = (time_steps, batch_size)
    assert result.shape == expected_shape