import pytest
import torch

from ding.torch_utils import is_differentiable
from ding.model.template.vae import VanillaVAE


@pytest.mark.unittest
def test_vae():
    """Exercise ``VanillaVAE`` end to end on a random batch.

    The test runs three stages:

    1. Call the model on a dict holding ``action``, ``obs`` and ``next_obs``
       tensors and check the shape (or type) of every returned entry.
    2. Call ``decode_with_obs`` on the sampled ``z`` together with the
       observations and check the shapes of both returned tensors.
    3. Attach the target action and the observation residual to the forward
       outputs, compute the loss and hand it to ``is_differentiable``
       together with the model.
    """
    batch_size = 32
    action_shape = 6
    original_action_shape = 2
    obs_shape = 6
    hidden_size_list = [256, 256]

    # Expected 2-D shapes, named once so the assertions below read uniformly.
    action_dims = (batch_size, original_action_shape)
    obs_dims = (batch_size, obs_shape)
    latent_dims = (batch_size, action_shape)

    # Random tensors are drawn before the model is built, in this key order.
    batch = {
        'action': torch.randn(*action_dims),
        'obs': torch.randn(*obs_dims),
        'next_obs': torch.randn(*obs_dims),
    }
    model = VanillaVAE(
        original_action_shape,
        obs_shape,
        action_shape,
        hidden_size_list,
    )

    # Stage 1: forward pass.
    forward_out = model(batch)
    assert forward_out['recons_action'].shape == action_dims
    assert forward_out['prediction_residual'].shape == obs_dims
    assert isinstance(forward_out['input'], dict)
    # NOTE(review): mu / log_var are compared against obs_shape while z is
    # compared against action_shape; both are 6 here, so the two cannot be
    # told apart by this test - confirm which dimension is intended.
    assert forward_out['mu'].shape == obs_dims
    assert forward_out['log_var'].shape == obs_dims
    assert forward_out['z'].shape == latent_dims

    # Stage 2: decoding from the latent sample and the observations.
    decoded = model.decode_with_obs(forward_out['z'], batch['obs'])
    assert decoded['reconstruction_action'].shape == action_dims
    # The key spelling below is the one this test reads from the decoder
    # output; it is kept exactly as is.
    assert decoded['predition_residual'].shape == obs_dims

    # Stage 3: the loss needs its targets stored next to the forward outputs.
    forward_out['original_action'] = batch['action']
    forward_out['true_residual'] = batch['next_obs'] - batch['obs']
    losses = model.loss_function(
        forward_out,
        kld_weight=0.01,
        predict_weight=0.01,
    )
    is_differentiable(losses['loss'], model)