|
"""Normalization layers used in blocks |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization whose affine parameters are supplied from outside.

    ``weight`` and ``bias`` start out as ``None`` and must be assigned before
    ``forward`` is called. ``forward`` folds the batch dimension into the
    channel dimension and runs ``F.batch_norm`` in training mode. The
    statistics are therefore computed separately for every (sample, channel)
    pair. ``weight`` and ``bias`` are consumed by that folded batch norm, so
    they need ``batch * num_features`` entries.

    Args:
        num_features: Number of channels expected in the input.
        eps: Value added to the variance for numerical stability.
        momentum: Momentum forwarded to ``F.batch_norm``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Affine parameters are injected externally; forward() refuses to run
        # until both have been assigned.
        self.weight = None
        self.bias = None

        # F.batch_norm needs running-stat tensors to write into. forward()
        # hands it tiled copies (Tensor.repeat), so these buffers themselves
        # are never updated.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        assert (
            self.weight is not None and self.bias is not None
        ), "Please assign weight and bias before calling AdaIN!"

        batch, channels = x.size(0), x.size(1)
        spatial = x.size()[2:]

        # View the whole batch as a single sample with batch*channels
        # channels, so batch-norm statistics become per-instance statistics.
        folded = x.contiguous().view(1, batch * channels, *spatial)
        normalized = F.batch_norm(
            folded,
            self.running_mean.repeat(batch),
            self.running_var.repeat(batch),
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=self.momentum,
            eps=self.eps,
        )
        return normalized.view(batch, channels, *spatial)

    def __repr__(self):
        return f"{type(self).__name__}({self.num_features})"
|
|
|
|
|
class LayerNorm(nn.Module):
    """Per-sample layer normalization with an optional per-channel affine.

    Each sample (index along dimension 0) is standardized with the mean and
    the unbiased standard deviation of all of its elements. If ``affine`` is
    set, the result is scaled by ``gamma`` and shifted by ``beta`` along
    dimension 1.

    Args:
        num_features: Size of dimension 1; the length of ``gamma``/``beta``.
        eps: Added to the standard deviation (not the variance) before dividing.
        affine: If True, create learnable ``gamma`` and ``beta`` parameters.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            # gamma starts uniformly random in [0, 1); beta starts at zero.
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Per-sample statistics, broadcast over every non-batch dimension.
        stat_shape = [-1] + [1] * (x.dim() - 1)

        # reshape (unlike view) accepts non-contiguous inputs such as the
        # result of transpose/permute. With a batch of one this reduces over
        # the whole tensor, so no special case is needed.
        flat = x.reshape(x.size(0), -1)
        mean = flat.mean(1).view(*stat_shape)
        std = flat.std(1).view(*stat_shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            # gamma/beta apply along dimension 1 (channels).
            param_shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*param_shape) + self.beta.view(*param_shape)
        return x
|
|
|
|
|
def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its norm (``Tensor.norm`` default, over all elements).

    ``eps`` is added to the norm before dividing, so an all-zero input yields
    zeros rather than NaNs.
    """
    magnitude = v.norm() + eps
    return v / magnitude
|
|
|
|
|
class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks"
    by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the
    Pytorch implementation:
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Wraps ``module`` and, before every forward pass, sets ``module.<name>`` to
    ``<name>_bar / sigma``. ``sigma`` is a power-iteration estimate of the
    largest singular value of ``<name>_bar`` reshaped to
    ``(shape[0], -1)``.

    On construction the original parameter is removed from ``module`` and
    re-registered as ``<name>_bar``. It is registered together with the
    power-iteration vectors ``<name>_u`` and ``<name>_v``, which are
    parameters with ``requires_grad=False``.

    Args:
        module: The module whose parameter is normalized.
        name: Name of the parameter to normalize.
        power_iterations: Power-iteration steps run per forward pass.
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Advance the power iteration and install the normalized weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            # Operates on .data, so the iteration itself records no autograd graph.
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # sigma = u^T W v; u and v do not require grad, so gradients flow
        # only through w.
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if ``module`` already has the ``_u``/``_v``/``_bar`` parameters."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Move ``module.<name>`` to ``<name>_bar`` and create the u/v vectors."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # Random unit vectors seed the power iteration.
        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # Remove the original parameter. _update_u_v() later sets a plain
        # tensor attribute with the same name, so module.<name> does not
        # exist until the first forward call.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args, **kwargs):
        """Refresh the normalized weight, then run the wrapped module's forward."""
        self._update_u_v()
        return self.module.forward(*args, **kwargs)
|
|
|
|
|
class SPADE(nn.Module):
    """Spatially-adaptive denormalization conditioned on a map.

    ``x`` is normalized with a parameter-free instance or batch norm. The
    conditioning map ``segmap`` is resized to ``x``'s spatial size with
    nearest-neighbour interpolation. It is then passed through a shared
    conv + ReLU layer and two conv heads that predict per-pixel ``gamma`` and
    ``beta``. The output is ``normalized * (1 + gamma) + beta``.

    Args:
        param_free_norm_type: ``"instance"`` or ``"batch"``.
        kernel_size: Kernel size of every conv. Must be odd so that padding
            ``kernel_size // 2`` preserves the spatial size.
        norm_nc: Channels of ``x`` (and of ``gamma``/``beta``).
        cond_nc: Channels of ``segmap``.
        nhidden: Channels of the shared hidden conv layer (default 128).

    Raises:
        ValueError: If ``param_free_norm_type`` is not recognized or
            ``kernel_size`` is even.
    """

    def __init__(
        self, param_free_norm_type, kernel_size, norm_nc, cond_nc, nhidden=128
    ):
        super().__init__()

        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                "%s is not a recognized param-free norm type in SPADE"
                % param_free_norm_type
            )

        # With an even kernel, padding k // 2 grows each spatial dim by one
        # per conv, so gamma/beta would no longer line up with x.
        if kernel_size % 2 == 0:
            raise ValueError(
                "SPADE kernel_size must be odd to preserve spatial size, got %s"
                % kernel_size
            )

        pw = kernel_size // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(cond_nc, nhidden, kernel_size=kernel_size, padding=pw),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=pw)

    def forward(self, x, segmap):
        normalized = self.param_free_norm(x)

        # Bring the conditioning map to x's spatial resolution.
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # "1 + gamma": a zero gamma leaves the normalized activations unscaled.
        return normalized * (1 + gamma) + beta
|
|