|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import torch.nn.init as init |
|
|
|
from .modules import InvertibleConv1x1 |
|
|
|
|
|
def initialize_weights(net_l, scale=1):
    """Kaiming-normal initialise every Conv2d/Linear layer, then rescale it.

    Args:
        net_l: A module or a list of modules; every sub-module reachable via
            ``.modules()`` is visited.
        scale: Multiplier applied to Conv2d/Linear weights after Kaiming
            initialisation (``0`` yields an all-zero layer).

    Conv2d/Linear biases are zeroed. BatchNorm2d layers get weight 1 and
    bias 0 when they have affine parameters; non-affine BatchNorm2d layers
    (``weight is None``) are left untouched instead of raising.
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode="fan_in")
                # In-place rescale without recording the op in autograd.
                with torch.no_grad():
                    m.weight.mul_(scale)
                    if m.bias is not None:
                        m.bias.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BatchNorm has no weight/bias tensors.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)
|
|
|
|
|
def initialize_weights_xavier(net_l, scale=1):
    """Xavier-normal initialise every Conv2d/Linear layer, then rescale it.

    Args:
        net_l: A module or a list of modules; every sub-module reachable via
            ``.modules()`` is visited.
        scale: Multiplier applied to Conv2d/Linear weights after Xavier
            initialisation (``0`` yields an all-zero layer).

    Conv2d/Linear biases are zeroed. BatchNorm2d layers get weight 1 and
    bias 0 when they have affine parameters; non-affine BatchNorm2d layers
    (``weight is None``) are left untouched instead of raising.
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.xavier_normal_(m.weight)
                # In-place rescale without recording the op in autograd.
                with torch.no_grad():
                    m.weight.mul_(scale)
                    if m.bias is not None:
                        m.bias.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BatchNorm has no weight/bias tensors.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)
|
|
|
|
|
class DenseBlock(nn.Module):
    """Five-layer densely connected 3x3 conv stack.

    Each of the first four convolutions sees the block input concatenated
    (along dim 1) with every earlier intermediate output and emits ``gc``
    channels followed by LeakyReLU(0.2). The fifth convolution consumes all of
    them and emits ``channel_out`` channels with no activation.

    Args:
        channel_in: Channels of the input tensor.
        channel_out: Channels produced by the final convolution.
        init: ``"xavier"`` selects Xavier-normal init for conv1-conv4; any
            other value selects Kaiming-normal. Either way they are scaled by
            0.1, and conv5 is Kaiming-initialised with scale 0 (all zeros).
        gc: Growth channels emitted by each intermediate convolution.
        bias: Whether the convolutions carry a bias term.

    NOTE: when used inside ``InvISPNet``, that network's ``initialize()``
    re-initialises every ``nn.Conv2d`` afterwards, overriding this scheme.
    """

    def __init__(self, channel_in, channel_out, init="xavier", gc=32, bias=True):
        super(DenseBlock, self).__init__()
        # Layers are created in this exact order so RNG consumption and
        # state_dict keys stay stable.
        self.conv1 = nn.Conv2d(channel_in, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(channel_in + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(channel_in + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(channel_in + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(channel_in + 4 * gc, channel_out, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        hidden_init = (
            initialize_weights_xavier if init == "xavier" else initialize_weights
        )
        hidden_init([self.conv1, self.conv2, self.conv3, self.conv4], 0.1)
        # Zero output layer: the block initially maps every input to zero.
        initialize_weights(self.conv5, 0)

    def forward(self, x):
        """Run the dense stack on ``x`` and return the conv5 output."""
        features = [x]
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(layer(torch.cat(features, 1))))
        return self.conv5(torch.cat(features, 1))
|
|
|
|
|
def subnet(net_structure, init="xavier"):
    """Return a factory that builds the coupling sub-networks used by InvBlock.

    Args:
        net_structure: Architecture name. Only ``"DBNet"`` (a ``DenseBlock``)
            is supported.
        init: Initialisation scheme forwarded to ``DenseBlock``: ``"xavier"``
            for Xavier-normal, any other value for Kaiming-normal.

    Returns:
        A callable ``constructor(channel_in, channel_out)`` returning the
        sub-network module.

    Raises:
        ValueError: raised by the returned constructor (not by ``subnet``
            itself) when ``net_structure`` is not supported.
    """

    def constructor(channel_in, channel_out):
        if net_structure == "DBNet":
            # Always forward ``init`` so non-"xavier" schemes actually reach
            # DenseBlock instead of falling back to its default.
            return DenseBlock(channel_in, channel_out, init)
        # Fail here with a clear message rather than returning None, which
        # would only surface later as "'NoneType' object is not callable".
        raise ValueError(f"Unsupported subnet structure: {net_structure!r}")

    return constructor
|
|
|
|
|
class InvBlock(nn.Module):
    """Invertible 1x1 convolution followed by an affine coupling layer.

    The input is split along dim 1 into ``x1`` (first ``channel_split_num``
    channels) and ``x2`` (the rest). Forward direction::

        y1 = x1 + F(x2)
        s  = clamp * (2 * sigmoid(H(y1)) - 1)      # s in (-clamp, clamp)
        y2 = x2 * exp(s) + G(y1)

    and ``rev=True`` applies the exact algebraic inverse, then undoes the 1x1
    convolution. The log-determinant returned by the 1x1 convolution is
    discarded.

    Args:
        subnet_constructor: Callable ``(channel_in, channel_out) -> module``
            used to build F, G and H.
        channel_num: Total number of input channels.
        channel_split_num: Channels in the first split (``x1``).
        clamp: Bound on the magnitude of the log-scale ``s``.
    """

    def __init__(self, subnet_constructor, channel_num, channel_split_num, clamp=0.8):
        super(InvBlock, self).__init__()
        self.split_len1 = channel_split_num
        self.split_len2 = channel_num - channel_split_num
        self.clamp = clamp

        # F: split2 -> split1 channels; G and H: split1 -> split2 channels.
        self.F = subnet_constructor(self.split_len2, self.split_len1)
        self.G = subnet_constructor(self.split_len1, self.split_len2)
        self.H = subnet_constructor(self.split_len1, self.split_len2)

        # The 1x1 conv mixes all input channels, so it must match channel_num
        # (previously hard-coded to 3, which only worked for 3-channel input).
        self.invconv = InvertibleConv1x1(channel_num, LU_decomposed=True)

    def flow_permutation(self, z, logdet, rev):
        """Apply (or invert, when ``rev``) the invertible 1x1 convolution.

        A method rather than a lambda attribute so the module can be pickled
        and so ``copy.deepcopy`` binds the copy to its own ``invconv``.
        """
        return self.invconv(z, logdet, rev)

    def _split(self, x):
        """Split ``x`` along dim 1 into the two coupling halves."""
        return (
            x.narrow(1, 0, self.split_len1),
            x.narrow(1, self.split_len1, self.split_len2),
        )

    def forward(self, x, rev=False):
        """Apply the block, or its inverse when ``rev`` is true.

        The log-scale ``s`` is stored on ``self.s`` after every call
        (presumably read elsewhere, e.g. by a regulariser — TODO confirm).
        """
        if not rev:
            x, logdet = self.flow_permutation(x, logdet=0, rev=False)
            x1, x2 = self._split(x)
            y1 = x1 + self.F(x2)
            self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
            y2 = x2.mul(torch.exp(self.s)) + self.G(y1)
            out = torch.cat((y1, y2), 1)
        else:
            x1, x2 = self._split(x)
            self.s = self.clamp * (torch.sigmoid(self.H(x1)) * 2 - 1)
            y2 = (x2 - self.G(x1)).div(torch.exp(self.s))
            y1 = x1 - self.F(y2)
            x = torch.cat((y1, y2), 1)
            out, logdet = self.flow_permutation(x, logdet=0, rev=True)
        return out
|
|
|
|
|
class InvISPNet(nn.Module):
    """Invertible network built from a stack of ``InvBlock`` layers.

    ``forward(x)`` runs the blocks in order; ``forward(x, rev=True)`` runs
    their inverses in reverse order.

    Args:
        channel_in: Channels of the input; every block preserves this count.
        channel_out: Accepted for interface compatibility but unused — each
            InvBlock outputs as many channels as it receives.
        subnet_constructor: Factory passed to every InvBlock. The default is
            built once at class-definition time; it holds no mutable state,
            so sharing it between instances is harmless.
        block_num: Number of stacked InvBlocks.
    """

    def __init__(
        self,
        channel_in=3,
        channel_out=3,
        subnet_constructor=subnet("DBNet"),
        block_num=8,
    ):
        super(InvISPNet, self).__init__()
        channel_split_num = 1
        self.operations = nn.ModuleList(
            [
                InvBlock(subnet_constructor, channel_in, channel_split_num)
                for _ in range(block_num)
            ]
        )
        # Overrides any initialisation done inside the sub-networks.
        self.initialize()

    def initialize(self):
        """Xavier-normal initialise all Conv2d/Linear layers (scale 1.0)."""
        initialize_weights_xavier(self, 1.0)

    def forward(self, x, rev=False):
        """Apply the stack, or its inverse when ``rev`` is true."""
        ops = reversed(self.operations) if rev else self.operations
        out = x
        for op in ops:
            # Call the module (not .forward) so registered hooks still run.
            out = op(out, rev)
        return out
|
|