import os

import pytest
import torch
from mmcv import Config

from risk_biased.models.latent_distributions import GaussianLatentDistribution


@pytest.fixture(scope="module")
def params():
    # Seed the RNG so the random posterior parameters below are reproducible.
    torch.manual_seed(0)
    # Load the learning config shipped with the repository.
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
    )
    cfg = Config.fromfile(config_path)
    # Override with small values to keep the test fast.
    cfg.batch_size = 4
    cfg.latent_dim = 2
    return cfg
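

# Note (illustrative, not part of the original file): the test below only
# reads `batch_size` and `latent_dim` from the loaded config, so a minimal
# stand-in fixture could be built without the repository's config file,
# e.g. `cfg = Config(dict(batch_size=4, latent_dim=2))`. The file-based
# fixture above is kept because it also checks that `learning_config.py`
# loads correctly.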


@pytest.mark.parametrize("threshold", [1e-5, 10.0])
def test_get_kl_loss(params, threshold: float):
    # Random posterior parameters, concatenated along the last dimension.
    z_mean_log_std = torch.rand(params.batch_size, 1, params.latent_dim * 2)
    distribution = GaussianLatentDistribution(z_mean_log_std)
    # Split into the two halves and convert the second half from
    # log-variance to log-std (log_std = log_var / 2).
    z_mean, z_log_var = torch.split(z_mean_log_std, params.latent_dim, dim=-1)
    z_log_std = z_log_var / 2.0
    # Expected closed-form KL divergence to a unit Gaussian,
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + 2*log_std - mu^2 - exp(2*log_std)),
    # clamped element-wise at `threshold` before averaging.
    kl_target = (
        (-0.5 * (1.0 + 2.0 * z_log_std - z_mean.square() - (2.0 * z_log_std).exp()))
        .clamp_min(threshold)
        .mean()
    )
    prior_z_mean_log_std = torch.zeros(params.latent_dim * 2)
    prior_distribution = GaussianLatentDistribution(prior_z_mean_log_std)
    # The KL loss between identical distributions is zero up to the clamping
    # threshold, hence atol=threshold.
    assert torch.isclose(
        distribution.kl_loss(distribution, threshold=threshold),
        torch.zeros(1),
        atol=threshold,
    )
    # Against a unit-Gaussian prior, the KL loss should match the closed form.
    assert torch.isclose(
        distribution.kl_loss(prior_distribution, threshold),
        kl_target,
    )
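

# Hypothetical cross-check (assumption: this helper is not part of the
# original test file, and its name is made up for illustration). It expresses
# the same closed form via torch.distributions: as threshold -> 0, the
# clamped `kl_target` above converges to the mean of this element-wise KL.
def reference_kl_to_unit_gaussian(
    z_mean: torch.Tensor, z_log_std: torch.Tensor
) -> torch.Tensor:
    """Element-wise KL(N(mu, sigma^2) || N(0, 1)) computed by torch."""
    posterior = torch.distributions.Normal(z_mean, z_log_std.exp())
    prior = torch.distributions.Normal(
        torch.zeros_like(z_mean), torch.ones_like(z_log_std)
    )
    return torch.distributions.kl_divergence(posterior, prior)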