|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .utils import split_feature, compute_same_pad |
|
|
|
|
|
def gaussian_p(mean, logs, x):
    """
    Element-wise log-density of `x` under an independent Gaussian.

    lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + k*ln(2*PI) }
          k = 1 (Independent)
          Var = exp(2 * logs), i.e. `logs` is the log standard deviation
    """
    log_var = logs * 2.0
    squared_dev = (x - mean) ** 2
    log_two_pi = math.log(2 * math.pi)
    return -0.5 * (log_var + squared_dev / torch.exp(log_var) + log_two_pi)
|
|
|
|
|
def gaussian_likelihood(mean, logs, x):
    """Sum the element-wise Gaussian log-density over dims 1, 2 and 3, leaving only dim 0."""
    return gaussian_p(mean, logs, x).sum(dim=[1, 2, 3])
|
|
|
|
|
def gaussian_sample(mean, logs, temperature=1):
    """
    Draw one sample per element from a Gaussian with a temperature-scaled std.

    Args:
        mean: tensor of per-element means.
        logs: tensor of per-element log standard deviations.
        temperature: multiplier applied to the standard deviation
            ``exp(logs)``. ``None`` is accepted and treated as the default
            of 1, because ``Split2d.forward`` in this file defaults its own
            ``temperature`` argument to ``None`` and forwards it here.

    Returns:
        Tensor of samples produced by ``torch.normal(mean, std)``.
    """
    if temperature is None:
        temperature = 1

    std = torch.exp(logs) * temperature
    return torch.normal(mean, std)
|
|
|
|
|
def squeeze2d(input, factor):
    """
    Trade spatial resolution for channels.

    A (B, C, H, W) tensor becomes (B, C * factor * factor, H // factor,
    W // factor): each `factor` x `factor` spatial patch is moved into the
    channel dimension. With `factor == 1` the input is returned untouched.
    """
    if factor == 1:
        return input

    batch, channels, height, width = input.size()
    assert height % factor == 0 and width % factor == 0, "H or W modulo factor is not 0"

    out_h = height // factor
    out_w = width // factor

    # Expose the in-patch offsets as their own axes, then move them next to C.
    patches = input.view(batch, channels, out_h, factor, out_w, factor)
    patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
    return patches.view(batch, channels * factor * factor, out_h, out_w)
|
|
|
|
|
def unsqueeze2d(input, factor):
    """
    Trade channels for spatial resolution.

    A (B, C, H, W) tensor becomes (B, C // factor**2, H * factor,
    W * factor): groups of `factor**2` channels are spread out into
    `factor` x `factor` spatial patches. With `factor == 1` the input is
    returned untouched.
    """
    if factor == 1:
        return input

    patch_area = factor**2
    batch, channels, height, width = input.size()
    assert channels % (patch_area) == 0, "C module factor squared is not 0"

    out_channels = channels // patch_area

    # Split the channel axis into (channel, row-offset, col-offset), then
    # interleave the offsets with the spatial axes.
    patches = input.view(batch, out_channels, factor, factor, height, width)
    patches = patches.permute(0, 1, 4, 2, 5, 3).contiguous()
    return patches.view(batch, out_channels, height * factor, width * factor)
|
|
|
|
|
class _ActNorm(nn.Module): |
|
""" |
|
Activation Normalization |
|
Initialize the bias and scale with a given minibatch, |
|
so that the output per-channel have zero mean and unit variance for that. |
|
|
|
After initialization, `bias` and `logs` will be trained as parameters. |
|
""" |
|
|
|
def __init__(self, num_features, scale=1.0): |
|
super().__init__() |
|
|
|
size = [1, num_features, 1, 1] |
|
self.bias = nn.Parameter(torch.zeros(*size)) |
|
self.logs = nn.Parameter(torch.zeros(*size)) |
|
self.num_features = num_features |
|
self.scale = scale |
|
self.inited = False |
|
|
|
def initialize_parameters(self, input): |
|
if not self.training: |
|
raise ValueError("In Eval mode, but ActNorm not inited") |
|
|
|
with torch.no_grad(): |
|
bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True) |
|
vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) |
|
logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) |
|
|
|
self.bias.data.copy_(bias.data) |
|
self.logs.data.copy_(logs.data) |
|
|
|
self.inited = True |
|
|
|
def _center(self, input, reverse=False): |
|
if reverse: |
|
return input - self.bias |
|
else: |
|
return input + self.bias |
|
|
|
def _scale(self, input, logdet=None, reverse=False): |
|
|
|
if reverse: |
|
input = input * torch.exp(-self.logs) |
|
else: |
|
input = input * torch.exp(self.logs) |
|
|
|
if logdet is not None: |
|
""" |
|
logs is log_std of `mean of channels` |
|
so we need to multiply by number of pixels |
|
""" |
|
b, c, h, w = input.shape |
|
|
|
dlogdet = torch.sum(self.logs) * h * w |
|
|
|
if reverse: |
|
dlogdet *= -1 |
|
|
|
logdet = logdet + dlogdet |
|
|
|
return input, logdet |
|
|
|
def forward(self, input, logdet=None, reverse=False): |
|
self._check_input_dim(input) |
|
|
|
if not self.inited: |
|
self.initialize_parameters(input) |
|
|
|
if reverse: |
|
input, logdet = self._scale(input, logdet, reverse) |
|
input = self._center(input, reverse) |
|
else: |
|
input = self._center(input, reverse) |
|
input, logdet = self._scale(input, logdet, reverse) |
|
|
|
return input, logdet |
|
|
|
|
|
class ActNorm2d(_ActNorm):
    """`_ActNorm` for 4-D inputs whose dim 1 has `num_features` channels."""

    def __init__(self, num_features, scale=1.0):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        """Assert that `input` is 4-D and has `num_features` channels in dim 1."""
        assert len(input.size()) == 4
        # Report the offending channel count (dim 1), not the whole shape.
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {}".format(
                self.num_features, input.size(1)
            )
        )
|
|
|
|
|
class LinearZeros(nn.Module):
    """
    Linear layer that starts out as the zero map.

    Weight and bias are zero-initialized; the output is multiplied by a
    learned per-output gain `exp(logs * logscale_factor)`.
    """

    def __init__(self, in_channels, out_channels, logscale_factor=3):
        super().__init__()

        self.linear = nn.Linear(in_channels, out_channels)
        with torch.no_grad():
            self.linear.weight.zero_()
            self.linear.bias.zero_()

        self.logscale_factor = logscale_factor
        self.logs = nn.Parameter(torch.zeros(out_channels))

    def forward(self, input):
        """Return `linear(input)` scaled by `exp(logs * logscale_factor)`."""
        gain = torch.exp(self.logs * self.logscale_factor)
        return self.linear(input) * gain
|
|
|
|
|
class Conv2d(nn.Module):
    """
    2-D convolution with normally-initialized weights and optional ActNorm.

    `padding` may be "same" (resolved through `compute_same_pad`), "valid"
    (no padding), or any value `nn.Conv2d` accepts directly. When
    `do_actnorm` is true the convolution has no bias and its output is fed
    through an `ActNorm2d`; otherwise a zero-initialized bias is used.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding="same",
        do_actnorm=True,
        weight_std=0.05,
    ):
        super().__init__()
        self.do_actnorm = do_actnorm

        # Translate the symbolic padding modes; anything else passes through.
        if padding == "same":
            padding = compute_same_pad(kernel_size, stride)
        elif padding == "valid":
            padding = 0

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=not do_actnorm,
        )

        # Replace the default weight init with N(0, weight_std).
        with torch.no_grad():
            self.conv.weight.normal_(mean=0.0, std=weight_std)

        if do_actnorm:
            self.actnorm = ActNorm2d(out_channels)
        else:
            with torch.no_grad():
                self.conv.bias.zero_()

    def forward(self, input):
        """Convolve `input`, then apply ActNorm if it is enabled."""
        out = self.conv(input)
        if not self.do_actnorm:
            return out
        out, _ = self.actnorm(out)
        return out
|
|
|
|
|
class Conv2dZeros(nn.Module):
    """
    2-D convolution that starts out as the zero map.

    Weight and bias are zero-initialized; the output is multiplied by a
    learned per-channel gain `exp(logs * logscale_factor)`. `padding` may
    be "same" (resolved through `compute_same_pad`), "valid" (no padding),
    or any value `nn.Conv2d` accepts directly.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding="same",
        logscale_factor=3,
    ):
        super().__init__()

        # Translate the symbolic padding modes; anything else passes through.
        if padding == "same":
            padding = compute_same_pad(kernel_size, stride)
        elif padding == "valid":
            padding = 0

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        with torch.no_grad():
            self.conv.weight.zero_()
            self.conv.bias.zero_()

        self.logscale_factor = logscale_factor
        # Shape (C, 1, 1) so the gain broadcasts over H and W.
        self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1))

    def forward(self, input):
        """Return `conv(input)` scaled by `exp(logs * logscale_factor)`."""
        gain = torch.exp(self.logs * self.logscale_factor)
        return self.conv(input) * gain
|
|
|
|
|
class Permute2d(nn.Module):
    """
    Fixed permutation of the channel dimension of a 4-D tensor.

    The permutation starts as a plain channel reversal; with `shuffle` it
    is replaced by a random permutation drawn at construction time.
    `indices_inverse` always holds the inverse permutation so that
    `forward(..., reverse=True)` undoes `forward(...)`.
    """

    def __init__(self, num_channels, shuffle):
        super().__init__()
        self.num_channels = num_channels
        self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long)
        self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long)
        self._refresh_inverse()

        if shuffle:
            self.reset_indices()

    def _refresh_inverse(self):
        """Rewrite `indices_inverse` in place so that inverse[indices[i]] == i."""
        positions = torch.arange(self.num_channels, dtype=torch.long)
        self.indices_inverse[self.indices] = positions

    def reset_indices(self):
        """Randomly re-order `indices` and bring `indices_inverse` in sync."""
        order = torch.randperm(self.indices.shape[0])
        self.indices = self.indices[order]
        self._refresh_inverse()

    def forward(self, input, reverse=False):
        """Permute the channels of `input`, or undo the permutation if `reverse`."""
        assert len(input.size()) == 4

        channel_order = self.indices_inverse if reverse else self.indices
        return input[:, channel_order, :, :]
|
|
|
|
|
class Split2d(nn.Module):
    """
    Split layer with a learned Gaussian prior on the dropped part.

    Forward splits the input into `z1` and `z2` via `split_feature`, keeps
    `z1`, and adds the Gaussian log-likelihood of `z2` (with mean / logs
    predicted from `z1`) to `logdet`. Reverse samples `z2` from that prior
    and concatenates it back onto `z1` along dim 1.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Predicts both mean and logs for the other half from z1.
        # NOTE(review): assumes split_feature(..., "split") yields
        # num_channels // 2 channels for z1 - confirm against utils.
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        """Return the (mean, logs) pair predicted from `z`."""
        h = self.conv(z)
        return split_feature(h, "cross")

    def forward(self, input, logdet=0.0, reverse=False, temperature=None):
        """
        Args:
            input: full tensor (forward) or the kept part `z1` (reverse).
            logdet: running log-determinant / log-likelihood accumulator.
            reverse: sample the missing part instead of scoring it.
            temperature: std multiplier used when sampling in reverse.
                ``None`` means "use `gaussian_sample`'s default"; passing
                ``None`` straight through would fail when it is multiplied
                with the std tensor.

        Returns:
            (z, logdet) in reverse, (z1, logdet) in forward.
        """
        if reverse:
            z1 = input
            mean, logs = self.split2d_prior(z1)
            if temperature is None:
                z2 = gaussian_sample(mean, logs)
            else:
                z2 = gaussian_sample(mean, logs, temperature)
            z = torch.cat((z1, z2), dim=1)
            return z, logdet

        z1, z2 = split_feature(input, "split")
        mean, logs = self.split2d_prior(z1)
        logdet = gaussian_likelihood(mean, logs, z2) + logdet
        return z1, logdet
|
|
|
|
|
class SqueezeLayer(nn.Module):
    """
    Module wrapper around `squeeze2d` / `unsqueeze2d`.

    Forward applies `squeeze2d`, reverse applies `unsqueeze2d`, both with
    the stored `factor`. `logdet` is returned unchanged.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        """Return `(rearranged input, logdet)`."""
        rearrange = unsqueeze2d if reverse else squeeze2d
        return rearrange(input, self.factor), logdet
|
|
|
|
|
class InvertibleConv1x1(nn.Module):
    """
    Invertible 1x1 convolution over the channel dimension.

    The `num_channels` x `num_channels` weight starts as a random
    orthogonal matrix (Q factor of a QR decomposition). With
    `LU_decomposed` the weight is stored as W = P @ L @ U, where P is a
    fixed permutation, L is unit lower-triangular, and the diagonal of U
    is stored as `sign_s * exp(log_s)`, so the log-determinant is simply
    `sum(log_s)`.
    """

    def __init__(self, num_channels, LU_decomposed):
        super().__init__()
        w_shape = [num_channels, num_channels]
        w_init = torch.linalg.qr(torch.randn(*w_shape))[0]

        if not LU_decomposed:
            self.weight = nn.Parameter(w_init)
        else:
            # torch.lu is deprecated; torch.linalg.lu returns P, L, U
            # directly. Fall back for torch versions that predate it.
            if hasattr(torch.linalg, "lu"):
                p, lower, upper = torch.linalg.lu(w_init)
            else:
                p, lower, upper = torch.lu_unpack(*torch.lu(w_init))

            s = torch.diag(upper)
            sign_s = torch.sign(s)
            log_s = torch.log(torch.abs(s))
            # Keep only the strictly-upper part; the diagonal lives in log_s.
            upper = torch.triu(upper, 1)
            l_mask = torch.tril(torch.ones(w_shape), -1)
            eye = torch.eye(*w_shape)

            self.register_buffer("p", p)
            self.register_buffer("sign_s", sign_s)
            self.lower = nn.Parameter(lower)
            self.log_s = nn.Parameter(log_s)
            self.upper = nn.Parameter(upper)
            # Non-persistent buffers: they follow the module across
            # devices/dtypes without adding keys to the state dict.
            self.register_buffer("l_mask", l_mask, persistent=False)
            self.register_buffer("eye", eye, persistent=False)

        self.w_shape = w_shape
        self.LU_decomposed = LU_decomposed

    def get_weight(self, input, reverse):
        """
        Build the convolution kernel and the forward log-determinant.

        Args:
            input: 4-D tensor; only its last two dims (h, w) are used.
            reverse: return the inverse matrix instead of the weight.

        Returns:
            (kernel of shape (C, C, 1, 1), dlogdet). `dlogdet` is always
            the forward value log|det W| * h * w; the caller applies the
            sign for the reverse direction.
        """
        _, _, h, w = input.shape

        if not self.LU_decomposed:
            dlogdet = torch.slogdet(self.weight)[1] * h * w
            weight = torch.inverse(self.weight) if reverse else self.weight
        else:
            # Masks keep L strictly lower (+ unit diagonal) and U strictly
            # upper (+ the signed exp(log_s) diagonal) during training.
            lower = self.lower * self.l_mask + self.eye
            u = self.upper * self.l_mask.transpose(0, 1).contiguous()
            u = u + torch.diag(self.sign_s * torch.exp(self.log_s))

            dlogdet = torch.sum(self.log_s) * h * w

            if reverse:
                # (P L U)^-1 = U^-1 L^-1 P^-1
                u_inv = torch.inverse(u)
                l_inv = torch.inverse(lower)
                p_inv = torch.inverse(self.p)
                weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv))
            else:
                weight = torch.matmul(self.p, torch.matmul(lower, u))

        return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """
        log-det = log|abs(|W|)| * pixels

        Applies the 1x1 convolution (or its inverse when `reverse`) and,
        if `logdet` is given, adds (forward) or subtracts (reverse) the
        log-determinant.

        Returns:
            (z, logdet); `logdet` stays None when None was passed in.
        """
        weight, dlogdet = self.get_weight(input, reverse)
        z = F.conv2d(input, weight)

        if logdet is not None:
            logdet = logdet - dlogdet if reverse else logdet + dlogdet

        return z, logdet
|
|