File size: 1,829 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
import numpy as np
import pytest
from itertools import product
from ding.model.template import PG
from ding.torch_utils import is_differentiable
from ding.utils import squeeze
# Batch dimension shared by every test below (leading dim of inputs/outputs).
B: int = 4
@pytest.mark.unittest
class TestDiscretePG:
    """Unit tests for the ``PG`` model template.

    Covers a discrete action space with (4, 84, 84) observations and a
    continuous action space with flat (N,) observations.
    """

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar loss and check gradients reach ``model``.

        Args:
            model: the module passed to ``is_differentiable`` together with the loss.
            outputs: a tensor, a list of tensors, or a dict whose values are
                tensors; every tensor is summed into one scalar loss.

        Raises:
            TypeError: if ``outputs`` is not a tensor, list or dict.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Fail with a clear message instead of an UnboundLocalError on ``loss``.
            raise TypeError(f"unsupported outputs type: {type(outputs).__name__}")
        is_differentiable(loss, model)

    def test_discrete_pg(self):
        """Discrete actions: logits are (B, action_shape) and sampling yields (B,)."""
        obs_shape = (4, 84, 84)
        action_shape = 5
        model = PG(
            obs_shape,
            action_shape,
        )
        # Build the batch from obs_shape so the two cannot drift apart.
        inputs = torch.randn(B, *obs_shape)
        outputs = model(inputs)
        assert isinstance(outputs, dict)
        assert outputs['logit'].shape == (B, action_shape)
        assert outputs['dist'].sample().shape == (B, )
        self.output_check(model, outputs['logit'])

    def test_continuous_pg(self):
        """Continuous actions: ``logit`` holds ``mu``/``sigma`` of shape (B, *action_shape)."""
        N = 32
        action_shape = (6, )
        obs = torch.randn(B, N)
        model = PG(
            obs_shape=(N, ),
            action_shape=action_shape,
            action_space='continuous',
        )
        # Forward pass on observations only: returns the action distribution
        # and the parameters (mu, sigma) it was built from.
        outputs = model(obs)
        assert isinstance(outputs, dict)
        action = outputs['dist'].sample()
        assert action.shape == (B, *action_shape)
        logit = outputs['logit']
        mu, sigma = logit['mu'], logit['sigma']
        assert mu.shape == (B, *action_shape)
        assert sigma.shape == (B, *action_shape)
        # Same gradient check as the discrete test: sum of mu and sigma.
        self.output_check(model, [mu, sigma])
|