# APISR/loss/anime_perceptual_loss.py
import os

import torch
import torch.nn as nn
from torchvision import models
'''
torchsummary-style summary of the full Danbooru2018 ResNet50 tagger
(3x224x224 input). _resnet() below keeps only the convolutional body
(rows 1-172); the fastai-style head (rows 173-183) is discarded.
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 4,096
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 16,384
BatchNorm2d-12 [-1, 256, 56, 56] 512
Conv2d-13 [-1, 256, 56, 56] 16,384
BatchNorm2d-14 [-1, 256, 56, 56] 512
ReLU-15 [-1, 256, 56, 56] 0
Bottleneck-16 [-1, 256, 56, 56] 0
Conv2d-17 [-1, 64, 56, 56] 16,384
BatchNorm2d-18 [-1, 64, 56, 56] 128
ReLU-19 [-1, 64, 56, 56] 0
Conv2d-20 [-1, 64, 56, 56] 36,864
BatchNorm2d-21 [-1, 64, 56, 56] 128
ReLU-22 [-1, 64, 56, 56] 0
Conv2d-23 [-1, 256, 56, 56] 16,384
BatchNorm2d-24 [-1, 256, 56, 56] 512
ReLU-25 [-1, 256, 56, 56] 0
Bottleneck-26 [-1, 256, 56, 56] 0
Conv2d-27 [-1, 64, 56, 56] 16,384
BatchNorm2d-28 [-1, 64, 56, 56] 128
ReLU-29 [-1, 64, 56, 56] 0
Conv2d-30 [-1, 64, 56, 56] 36,864
BatchNorm2d-31 [-1, 64, 56, 56] 128
ReLU-32 [-1, 64, 56, 56] 0
Conv2d-33 [-1, 256, 56, 56] 16,384
BatchNorm2d-34 [-1, 256, 56, 56] 512
ReLU-35 [-1, 256, 56, 56] 0
Bottleneck-36 [-1, 256, 56, 56] 0
Conv2d-37 [-1, 128, 56, 56] 32,768
BatchNorm2d-38 [-1, 128, 56, 56] 256
ReLU-39 [-1, 128, 56, 56] 0
Conv2d-40 [-1, 128, 28, 28] 147,456
BatchNorm2d-41 [-1, 128, 28, 28] 256
ReLU-42 [-1, 128, 28, 28] 0
Conv2d-43 [-1, 512, 28, 28] 65,536
BatchNorm2d-44 [-1, 512, 28, 28] 1,024
Conv2d-45 [-1, 512, 28, 28] 131,072
BatchNorm2d-46 [-1, 512, 28, 28] 1,024
ReLU-47 [-1, 512, 28, 28] 0
Bottleneck-48 [-1, 512, 28, 28] 0
Conv2d-49 [-1, 128, 28, 28] 65,536
BatchNorm2d-50 [-1, 128, 28, 28] 256
ReLU-51 [-1, 128, 28, 28] 0
Conv2d-52 [-1, 128, 28, 28] 147,456
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
Conv2d-55 [-1, 512, 28, 28] 65,536
BatchNorm2d-56 [-1, 512, 28, 28] 1,024
ReLU-57 [-1, 512, 28, 28] 0
Bottleneck-58 [-1, 512, 28, 28] 0
Conv2d-59 [-1, 128, 28, 28] 65,536
BatchNorm2d-60 [-1, 128, 28, 28] 256
ReLU-61 [-1, 128, 28, 28] 0
Conv2d-62 [-1, 128, 28, 28] 147,456
BatchNorm2d-63 [-1, 128, 28, 28] 256
ReLU-64 [-1, 128, 28, 28] 0
Conv2d-65 [-1, 512, 28, 28] 65,536
BatchNorm2d-66 [-1, 512, 28, 28] 1,024
ReLU-67 [-1, 512, 28, 28] 0
Bottleneck-68 [-1, 512, 28, 28] 0
Conv2d-69 [-1, 128, 28, 28] 65,536
BatchNorm2d-70 [-1, 128, 28, 28] 256
ReLU-71 [-1, 128, 28, 28] 0
Conv2d-72 [-1, 128, 28, 28] 147,456
BatchNorm2d-73 [-1, 128, 28, 28] 256
ReLU-74 [-1, 128, 28, 28] 0
Conv2d-75 [-1, 512, 28, 28] 65,536
BatchNorm2d-76 [-1, 512, 28, 28] 1,024
ReLU-77 [-1, 512, 28, 28] 0
Bottleneck-78 [-1, 512, 28, 28] 0
Conv2d-79 [-1, 256, 28, 28] 131,072
BatchNorm2d-80 [-1, 256, 28, 28] 512
ReLU-81 [-1, 256, 28, 28] 0
Conv2d-82 [-1, 256, 14, 14] 589,824
BatchNorm2d-83 [-1, 256, 14, 14] 512
ReLU-84 [-1, 256, 14, 14] 0
Conv2d-85 [-1, 1024, 14, 14] 262,144
BatchNorm2d-86 [-1, 1024, 14, 14] 2,048
Conv2d-87 [-1, 1024, 14, 14] 524,288
BatchNorm2d-88 [-1, 1024, 14, 14] 2,048
ReLU-89 [-1, 1024, 14, 14] 0
Bottleneck-90 [-1, 1024, 14, 14] 0
Conv2d-91 [-1, 256, 14, 14] 262,144
BatchNorm2d-92 [-1, 256, 14, 14] 512
ReLU-93 [-1, 256, 14, 14] 0
Conv2d-94 [-1, 256, 14, 14] 589,824
BatchNorm2d-95 [-1, 256, 14, 14] 512
ReLU-96 [-1, 256, 14, 14] 0
Conv2d-97 [-1, 1024, 14, 14] 262,144
BatchNorm2d-98 [-1, 1024, 14, 14] 2,048
ReLU-99 [-1, 1024, 14, 14] 0
Bottleneck-100 [-1, 1024, 14, 14] 0
Conv2d-101 [-1, 256, 14, 14] 262,144
BatchNorm2d-102 [-1, 256, 14, 14] 512
ReLU-103 [-1, 256, 14, 14] 0
Conv2d-104 [-1, 256, 14, 14] 589,824
BatchNorm2d-105 [-1, 256, 14, 14] 512
ReLU-106 [-1, 256, 14, 14] 0
Conv2d-107 [-1, 1024, 14, 14] 262,144
BatchNorm2d-108 [-1, 1024, 14, 14] 2,048
ReLU-109 [-1, 1024, 14, 14] 0
Bottleneck-110 [-1, 1024, 14, 14] 0
Conv2d-111 [-1, 256, 14, 14] 262,144
BatchNorm2d-112 [-1, 256, 14, 14] 512
ReLU-113 [-1, 256, 14, 14] 0
Conv2d-114 [-1, 256, 14, 14] 589,824
BatchNorm2d-115 [-1, 256, 14, 14] 512
ReLU-116 [-1, 256, 14, 14] 0
Conv2d-117 [-1, 1024, 14, 14] 262,144
BatchNorm2d-118 [-1, 1024, 14, 14] 2,048
ReLU-119 [-1, 1024, 14, 14] 0
Bottleneck-120 [-1, 1024, 14, 14] 0
Conv2d-121 [-1, 256, 14, 14] 262,144
BatchNorm2d-122 [-1, 256, 14, 14] 512
ReLU-123 [-1, 256, 14, 14] 0
Conv2d-124 [-1, 256, 14, 14] 589,824
BatchNorm2d-125 [-1, 256, 14, 14] 512
ReLU-126 [-1, 256, 14, 14] 0
Conv2d-127 [-1, 1024, 14, 14] 262,144
BatchNorm2d-128 [-1, 1024, 14, 14] 2,048
ReLU-129 [-1, 1024, 14, 14] 0
Bottleneck-130 [-1, 1024, 14, 14] 0
Conv2d-131 [-1, 256, 14, 14] 262,144
BatchNorm2d-132 [-1, 256, 14, 14] 512
ReLU-133 [-1, 256, 14, 14] 0
Conv2d-134 [-1, 256, 14, 14] 589,824
BatchNorm2d-135 [-1, 256, 14, 14] 512
ReLU-136 [-1, 256, 14, 14] 0
Conv2d-137 [-1, 1024, 14, 14] 262,144
BatchNorm2d-138 [-1, 1024, 14, 14] 2,048
ReLU-139 [-1, 1024, 14, 14] 0
Bottleneck-140 [-1, 1024, 14, 14] 0
Conv2d-141 [-1, 512, 14, 14] 524,288
BatchNorm2d-142 [-1, 512, 14, 14] 1,024
ReLU-143 [-1, 512, 14, 14] 0
Conv2d-144 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-145 [-1, 512, 7, 7] 1,024
ReLU-146 [-1, 512, 7, 7] 0
Conv2d-147 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-148 [-1, 2048, 7, 7] 4,096
Conv2d-149 [-1, 2048, 7, 7] 2,097,152
BatchNorm2d-150 [-1, 2048, 7, 7] 4,096
ReLU-151 [-1, 2048, 7, 7] 0
Bottleneck-152 [-1, 2048, 7, 7] 0
Conv2d-153 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-154 [-1, 512, 7, 7] 1,024
ReLU-155 [-1, 512, 7, 7] 0
Conv2d-156 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-157 [-1, 512, 7, 7] 1,024
ReLU-158 [-1, 512, 7, 7] 0
Conv2d-159 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-160 [-1, 2048, 7, 7] 4,096
ReLU-161 [-1, 2048, 7, 7] 0
Bottleneck-162 [-1, 2048, 7, 7] 0
Conv2d-163 [-1, 512, 7, 7] 1,048,576
BatchNorm2d-164 [-1, 512, 7, 7] 1,024
ReLU-165 [-1, 512, 7, 7] 0
Conv2d-166 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-167 [-1, 512, 7, 7] 1,024
ReLU-168 [-1, 512, 7, 7] 0
Conv2d-169 [-1, 2048, 7, 7] 1,048,576
BatchNorm2d-170 [-1, 2048, 7, 7] 4,096
ReLU-171 [-1, 2048, 7, 7] 0
Bottleneck-172 [-1, 2048, 7, 7] 0
AdaptiveMaxPool2d-173 [-1, 2048, 1, 1] 0
AdaptiveAvgPool2d-174 [-1, 2048, 1, 1] 0
AdaptiveConcatPool2d-175 [-1, 4096, 1, 1] 0
Flatten-176 [-1, 4096] 0
BatchNorm1d-177 [-1, 4096] 8,192
Dropout-178 [-1, 4096] 0
Linear-179 [-1, 512] 2,097,664
ReLU-180 [-1, 512] 0
BatchNorm1d-181 [-1, 512] 1,024
Dropout-182 [-1, 512] 0
Linear-183 [-1, 6000] 3,078,000
================================================================
Total params: 28,692,912
Trainable params: 28,692,912
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.75
Params size (MB): 109.45
Estimated Total Size (MB): 396.78
----------------------------------------------------------------
'''
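# The table above matches the output format of the third-party `torchsummary`
# package; assuming that is how it was generated, it could be reproduced with
# something like (hypothetical `full_tagger_with_head` is the body plus head):
#     from torchsummary import summary
#     summary(full_tagger_with_head.cuda(), (3, 224, 224))
# The resnet50() defined below returns only the headless body, so summarizing
# it would reproduce rows 1-172 only.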
class AdaptiveConcatPool2d(nn.Module):
    """
    Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
    Source: Fastai. This code was taken from the fastai library at url
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
    """
    def __init__(self, sz=None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        self.output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
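# Hedged usage sketch (illustrative, not part of the original file): for the
# (N, 2048, 7, 7) body output listed in the summary above, AdaptiveConcatPool2d
# doubles the channel count by concatenating the max- and avg-pooled maps.
def _demo_adaptive_concat_pool():
    pool = AdaptiveConcatPool2d()
    feats = torch.randn(1, 2048, 7, 7)
    assert pool(feats).shape == (1, 4096, 1, 1)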
class Flatten(nn.Module):
    """
    Flatten `x` to a single dimension. Adapted from fastai's Flatten() layer,
    at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L25
    """
    def __init__(self): super().__init__()

    def forward(self, x): return x.view(x.size(0), -1)
def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn=None):
    """
    Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`, `n_out`) layers followed by `actn`.
    Adapted from Fastai at https://github.com/fastai/fastai/blob/master/fastai/layers.py#L44
    """
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0: layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None: layers.append(actn)
    return layers
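# Hedged sketch (illustrative): one BN -> Dropout -> Linear -> ReLU block of
# the fastai head, matching rows 177-180 of the summary above.
def _demo_bn_drop_lin():
    block = nn.Sequential(*bn_drop_lin(4096, 512, bn=True, p=0.25,
                                       actn=nn.ReLU(inplace=True)))
    x = torch.randn(2, 4096)    # BatchNorm1d needs batch size > 1 in train mode
    assert block(x).shape == (2, 512)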
def create_head(top_n_tags, nf, ps=0.5):
    nc = top_n_tags
    lin_ftrs = [nf, 512, nc]
    p1 = 0.25    # dropout for second-last layer
    p2 = 0.5     # dropout for last layer
    pool = AdaptiveConcatPool2d()
    layers = [pool, Flatten()]
    layers += [
        *bn_drop_lin(lin_ftrs[0], lin_ftrs[1], True, p1, nn.ReLU(inplace=True)),
        *bn_drop_lin(lin_ftrs[1], lin_ftrs[2], True, p2)
    ]
    return nn.Sequential(*layers)
def _resnet(base_arch, top_n, **kwargs):
    cut = -2    # drop the pooling + classifier layers, keep the conv body
    s = base_arch(pretrained=False, **kwargs)
    body = nn.Sequential(*list(s.children())[:cut])
    if base_arch in [models.resnet18, models.resnet34]:
        num_features_model = 512
    elif base_arch in [models.resnet50, models.resnet101]:
        num_features_model = 2048
    else:
        raise ValueError("Unsupported base architecture for _resnet")
    nf = num_features_model * 2    # AdaptiveConcatPool2d doubles the channels
    nc = top_n
    # head = create_head(nc, nf)
    model = body    # nn.Sequential(body, head): the head is not needed for feature extraction
    return model
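# Hedged sketch (illustrative): the headless body returned by _resnet maps a
# 3x224x224 image to the (N, 2048, 7, 7) volume of Bottleneck-172 above.
def _demo_resnet_body():
    body = _resnet(models.resnet50, top_n=6000)
    with torch.no_grad():
        feats = body(torch.randn(1, 3, 224, 224))
    assert feats.shape == (1, 2048, 7, 7)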
def resnet50(pretrained=True, progress=True, top_n=6000, **kwargs):
    r"""
    ResNet50 model trained on the full Danbooru2018 dataset's top 6000 tags.

    Args:
        pretrained (bool): kwargs, load pretrained weights into the model.
        top_n (int): kwargs, load the model that predicts the top `n` tags;
            currently only top_n=6000 is supported.
    """
    model = _resnet(models.resnet50, top_n, **kwargs)    # ResNet without the head (the final FC layers are not needed)
    if pretrained:
        if top_n == 6000:
            state = torch.hub.load_state_dict_from_url("https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
                                                       progress=progress)
            # The checkpoint stores the body under the '0.' prefix and the
            # classifier head under the '1.' prefix: strip the former, drop the latter
            old_keys = [key for key in state]
            for old_key in old_keys:
                if old_key[0] == '0':
                    new_key = old_key[2:]
                    state[new_key] = state[old_key]
                    del state[old_key]
                elif old_key[0] == '1':
                    del state[old_key]
            model.load_state_dict(state)
        else:
            raise ValueError("Sorry, the resnet50 model only supports the top-6000 tags at the moment")
    return model
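# Hedged sketch (illustrative; downloads the ~100 MB checkpoint on first call):
def _demo_pretrained_resnet50():
    model = resnet50(pretrained=True).eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    assert feats.shape == (1, 2048, 7, 7)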
class resnet50_Extractor(nn.Module):
    """ResNet50 network for feature extraction.

    Layer labels address modules inside the Sequential body: a plain label
    like "0" hooks model.0, while an underscore label like "4_2_conv3" hooks
    model[4][2].conv3 (all but the last element are integer indices).
    """
    def get_activation(self, name):
        def hook(model, input, output):
            # Keep the graph: detaching here would cut gradient flow from the
            # perceptual loss back to the generator
            self.activation[name] = output
        return hook

    def __init__(self,
                 model,
                 layer_labels,
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False
                 ):
        super(resnet50_Extractor, self).__init__()
        self.model = model
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm
        self.layer_labels = layer_labels
        self.activation = {}

        # Register a forward hook on every layer we need features from
        for layer_label in layer_labels:
            elements = layer_label.split('_')
            if len(elements) == 1:
                getattr(self.model, elements[0]).register_forward_hook(self.get_activation(layer_label))
            else:
                body_layer = self.model
                for element in elements[:-1]:
                    # Walk down by integer index until the last element
                    body_layer = body_layer[int(element)]
                getattr(body_layer, elements[-1]).register_forward_hook(self.get_activation(layer_label))

        # Freeze and set to evaluation mode
        if not requires_grad:
            self.model.eval()
            for param in self.parameters():
                param.requires_grad = False

        if self.use_input_norm:
            # ImageNet mean/std, for images in range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Mapping from layer label to the hooked feature map.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        # Run the full model; its final output is discarded, we only keep the
        # intermediate activations captured by the forward hooks
        self.model(x)
        store = {}
        for layer_label in self.layer_labels:
            store[layer_label] = self.activation[layer_label]
        return store
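# Hedged sketch of the label convention (illustrative): "4_2_conv3" resolves
# to model[4][2].conv3, i.e. conv3 of the third Bottleneck in the first
# residual stage, whose output is (N, 256, 56, 56) per the summary above.
def _demo_extractor_labels():
    extractor = resnet50_Extractor(resnet50(pretrained=False), ["4_2_conv3"])
    feats = extractor(torch.randn(1, 3, 224, 224))
    assert feats["4_2_conv3"].shape == (1, 256, 56, 56)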
class Anime_PerceptualLoss(nn.Module):
    """Anime perceptual loss.

    Args:
        layer_weights (dict): The weight for each hooked layer of the ResNet50
            feature extractor. For example, {'4_2_conv3': 1.} extracts the
            conv3 output of model[4][2] with weight 1.0 in the loss.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and multiplied by this weight.
            Default: 1.0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """
    def __init__(self,
                 layer_weights,
                 perceptual_weight=1.0,
                 criterion='l1'):
        super(Anime_PerceptualLoss, self).__init__()
        model = resnet50()
        self.perceptual_weight = perceptual_weight
        self.layer_weights = layer_weights
        self.layer_labels = layer_weights.keys()
        self.resnet50 = resnet50_Extractor(model, self.layer_labels).cuda()

        if criterion == 'l1':
            self.criterion = torch.nn.L1Loss()
        else:
            raise NotImplementedError("Only the 'l1' criterion is supported in this perceptual loss")
    def forward(self, gen, gt):
        """Forward function.

        Args:
            gen (Tensor): Generated tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Weighted perceptual loss, or None if perceptual_weight <= 0.
        """
        # Extract ResNet50 features
        gen_features = self.resnet50(gen)
        gt_features = self.resnet50(gt.detach())

        temp_store = []
        # Calculate the perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for idx, k in enumerate(gen_features.keys()):
                raw_comparison = self.criterion(gen_features[k], gt_features[k])
                percep_loss += raw_comparison * self.layer_weights[k]
                # print("layer" + str(idx) + " has loss " + str(raw_comparison.detach().cpu().numpy()))
                # temp_store.append(float(raw_comparison.detach().cpu().numpy()))
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # The first return value is for debugging purposes only (populated when
        # the temp_store line above is uncommented)
        if len(temp_store) != 0:
            return temp_store, percep_loss
        else:
            return percep_loss
if __name__ == "__main__":
    import torchvision.transforms as transforms
    import cv2
    import collections

    # NOTE: the per-layer statistics below require uncommenting the
    # temp_store.append(...) line in Anime_PerceptualLoss.forward();
    # otherwise forward() returns only the scalar loss.
    loss = Anime_PerceptualLoss({"0": 0.5, "4_2_conv3": 20, "5_3_conv3": 30, "6_5_conv3": 1, "7_2_conv3": 1}).cuda()
    store = collections.defaultdict(list)
    for img_name in sorted(os.listdir('datasets/train_gen/')):
        # cv2 loads BGR uint8 HWC; ToTensor scales to [0, 1]; add a batch dim
        gen = transforms.ToTensor()(cv2.imread('datasets/train_gen/' + img_name)).unsqueeze(0).cuda()
        gt = transforms.ToTensor()(cv2.imread('datasets/train_hr_anime_usm/' + img_name)).unsqueeze(0).cuda()
        temp_store, _ = loss(gen, gt)
        for idx in range(len(temp_store)):
            store[idx].append(temp_store[idx])

    for idx in range(len(store)):
        print("Average layer" + str(idx) + " has loss " + str(sum(store[idx]) / len(store[idx])))

    # model = loss.resnet50
    # pytorch_total_params = sum(p.numel() for p in model.parameters())
    # print(f"Perceptual ResNet50 has param {pytorch_total_params//1000000} M params")