File size: 3,004 Bytes
c4c7cee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Generate multivariate von Mises Fisher samples.
PyTorch implementation of the original code from:
https://github.com/clara-labs/spherecluster
"""

import torch

__all__ = ["sample_vMF"]


def vMF_sampler(
    net,
    batch,
):
    """Sample from the von Mises-Fisher distribution predicted by a network.

    ``net(batch)`` must return a ``(mu, kappa)`` pair. ``mu`` is transposed
    and dimension 1 of ``kappa`` is squeezed away before both are passed to
    ``sample_vMF``. Presumably ``mu`` is ``(batch, dim)`` and ``kappa`` is
    ``(batch, 1)`` -- TODO confirm against the network definition.

    Returns:
        Whatever ``sample_vMF`` returns for the transposed means and the
        squeezed concentrations.
    """
    mean_dirs, concentrations = net(batch)
    centers = mean_dirs.T
    per_row_kappa = concentrations.squeeze(1)
    return sample_vMF(centers, per_row_kappa)


def vMF_mixture_sampler(
    net,
    batch,
):
    """Sample from a mixture of von Mises-Fisher distributions predicted by a network.

    ``net(batch)`` must return ``(mu_mixture, kappa_mixture, weights)``. For
    every row of ``weights`` one component index is drawn with
    ``torch.multinomial``; the matching mean and concentration are then
    gathered and passed to ``sample_vMF``. Presumably ``mu_mixture`` is
    ``(batch, components, dim)`` and ``kappa_mixture`` / ``weights`` are
    ``(batch, components)`` -- TODO confirm against the network definition.

    Returns:
        Whatever ``sample_vMF`` returns for the selected means (transposed)
        and the selected concentrations.
    """
    component_mus, component_kappas, mix_weights = net(batch)

    # One categorical draw per row of the weights picks the active component.
    chosen = torch.multinomial(mix_weights, num_samples=1).squeeze()

    # Pair each leading index with its chosen component, separately for the
    # means and for the concentrations.
    mu_rows = torch.arange(component_mus.shape[0])
    picked_mu = component_mus[mu_rows, chosen]
    kappa_rows = torch.arange(component_kappas.shape[0])
    picked_kappa = component_kappas[kappa_rows, chosen]

    return sample_vMF(picked_mu.T, picked_kappa)


def sample_vMF(mu, kappa, num_samples=1):
    """Generate samples from von Mises-Fisher distributions.

    Each sample is built as ``w * mu + sqrt(1 - w**2) * v`` where ``w`` comes
    from ``_sample_weight`` and ``v`` is a random unit vector orthogonal to
    ``mu``. The result lies on the unit sphere only if every column of ``mu``
    has unit norm; ``mu`` is not normalised here.

    Args:
        mu: Mean direction(s). Either a vector of shape ``(N,)`` or a matrix
            of shape ``(N, B)`` with one N-dimensional centre per column.
        kappa: Concentration(s). Either a 1-D tensor of length ``B`` (one
            value per column of ``mu``), or a single Python number / 0-d
            tensor that is shared by every column.
        num_samples: Number of draws per column of ``mu``. Only used when
            ``kappa`` is a single number; it is ignored when ``kappa`` is a
            1-D tensor, in which case exactly one sample per column is drawn.

    Returns:
        Tensor with one sample per row: shape ``(B, N)`` for a 1-D tensor
        ``kappa`` and ``(B * num_samples, N)`` for a scalar ``kappa`` (the
        columns of ``mu`` are tiled ``num_samples`` times, so rows cycle
        through the centres).

    Raises:
        AssertionError: If ``kappa`` is a 1-D tensor whose length differs
            from the number of columns of ``mu``.
    """
    if mu.dim() == 1:
        mu = mu.unsqueeze(1)
    dim = mu.shape[0]

    if isinstance(kappa, torch.Tensor) and kappa.dim() > 0:
        assert mu.shape[1] == kappa.size(0), (
            f"kappa has {kappa.size(0)} entries but mu has "
            f"{mu.shape[1]} columns"
        )
    else:
        if isinstance(kappa, torch.Tensor):
            # A 0-d tensor behaves like a plain number.
            kappa = kappa.item()
        mu = mu.repeat(1, num_samples)
        # One concentration per column of the tiled mu, so that every column
        # gets its own independent weight even when mu holds several centres.
        kappa = torch.full(
            (mu.shape[1],), kappa, device=mu.device, dtype=mu.dtype
        )

    # sample offset from center (on sphere) with spread kappa
    w = _sample_weight(kappa, dim)

    # sample a point v on the unit sphere that's orthogonal to mu
    v = _sample_orthonormal_to(mu)

    # compute new point
    result = v * torch.sqrt(1.0 - w**2).unsqueeze(0) + w.unsqueeze(0) * mu
    return result.T


def _sample_weight(kappa, dim):
    """Rejection sampling scheme for sampling distance from center on
    surface of the sphere.
    """
    dim = dim - 1  # since S^{n-1}
    try:
        size = kappa.size(0)
    except AttributeError:
        size = 1

    b = dim / (torch.sqrt(4.0 * kappa**2 + dim**2) + 2 * kappa)
    x = (1.0 - b) / (1.0 + b)
    c = kappa * x + dim * torch.log(1 - x**2)

    w = torch.zeros_like(kappa)
    idx = torch.zeros_like(kappa, dtype=torch.bool)

    while True:
        where_zero = ~idx
        if torch.all(idx):
            return w

        z = (
            torch.distributions.Beta(dim / 2.0, dim / 2.0)
            .sample((size,))
            .to(kappa.device)
        )
        _w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
        u = torch.rand(size, device=kappa.device)

        _idx = kappa * _w + dim * torch.log(1.0 - x * _w) - c >= torch.log(u)

        if not torch.any(_idx):
            continue

        w[where_zero] = _w[where_zero]
        idx[_idx] = True


def _sample_orthonormal_to(mu):
    """Sample point on sphere orthogonal to mu."""
    v = torch.randn(mu.shape[0], mu.shape[1], device=mu.device)
    proj_mu_v = mu * ((v * mu).sum(dim=0)) / torch.norm(mu, dim=0) ** 2
    orthto = v - proj_mu_v
    return orthto / torch.norm(orthto, dim=0)