zjowowen's picture
init space
079c32c
raw
history blame
1.5 kB
import pytest
import torch
from ding.torch_utils import is_differentiable
from ding.model.template.vae import VanillaVAE
@pytest.mark.unittest
def test_vae():
    """Run VanillaVAE end to end on random data and check its outputs.

    The test performs a forward pass, asserts the shape of each returned
    tensor, decodes from the returned ``z`` together with the observation,
    and finally checks that the loss is differentiable w.r.t. the model.
    """
    n_samples = 32
    z_dim = 6  # third constructor argument; also the expected width of ``z``
    raw_action_dim = 2
    obs_dim = 6
    hidden_sizes = [256, 256]

    batch = {
        'action': torch.randn(n_samples, raw_action_dim),
        'obs': torch.randn(n_samples, obs_dim),
        'next_obs': torch.randn(n_samples, obs_dim),
    }

    model = VanillaVAE(raw_action_dim, obs_dim, z_dim, hidden_sizes)
    forward_out = model(batch)

    # Reconstruction and residual-prediction heads.
    for key, width in (('recons_action', raw_action_dim), ('prediction_residual', obs_dim)):
        assert forward_out[key].shape == (n_samples, width)

    assert isinstance(forward_out['input'], dict)

    # NOTE(review): obs_dim and z_dim are both 6 here, so the ``mu`` and
    # ``log_var`` checks pass regardless of which of the two widths the model
    # really uses -- confirm the intended dimension.
    for key, width in (('mu', obs_dim), ('log_var', obs_dim), ('z', z_dim)):
        assert forward_out[key].shape == (n_samples, width)

    # Decode from the sampled ``z`` conditioned on the observation.
    decoded = model.decode_with_obs(forward_out['z'], batch['obs'])
    assert decoded['reconstruction_action'].shape == (n_samples, raw_action_dim)
    # The key spelling below is what the test has always looked up; keep it as is.
    assert decoded['predition_residual'].shape == (n_samples, obs_dim)

    # The loss function reads its targets from the same dict as the forward outputs.
    forward_out['original_action'] = batch['action']
    forward_out['true_residual'] = batch['next_obs'] - batch['obs']
    losses = model.loss_function(forward_out, kld_weight=0.01, predict_weight=0.01)
    is_differentiable(losses['loss'], model)