import os

import cv2
import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
'''
torchsummary output for the pretrained Danbooru ResNet-50 tagger
(body plus fastai-style head) at input size (3, 224, 224):
------------------------------------------------------------
            Layer (type)            Output Shape     Param #
============================================================
                Conv2d-1      [-1, 64, 112, 112]       9,408
           BatchNorm2d-2      [-1, 64, 112, 112]         128
                  ReLU-3      [-1, 64, 112, 112]           0
             MaxPool2d-4        [-1, 64, 56, 56]           0
                Conv2d-5        [-1, 64, 56, 56]       4,096
           BatchNorm2d-6        [-1, 64, 56, 56]         128
                  ReLU-7        [-1, 64, 56, 56]           0
                Conv2d-8        [-1, 64, 56, 56]      36,864
           BatchNorm2d-9        [-1, 64, 56, 56]         128
                 ReLU-10        [-1, 64, 56, 56]           0
               Conv2d-11       [-1, 256, 56, 56]      16,384
          BatchNorm2d-12       [-1, 256, 56, 56]         512
               Conv2d-13       [-1, 256, 56, 56]      16,384
          BatchNorm2d-14       [-1, 256, 56, 56]         512
                 ReLU-15       [-1, 256, 56, 56]           0
           Bottleneck-16       [-1, 256, 56, 56]           0
               Conv2d-17        [-1, 64, 56, 56]      16,384
          BatchNorm2d-18        [-1, 64, 56, 56]         128
                 ReLU-19        [-1, 64, 56, 56]           0
               Conv2d-20        [-1, 64, 56, 56]      36,864
          BatchNorm2d-21        [-1, 64, 56, 56]         128
                 ReLU-22        [-1, 64, 56, 56]           0
               Conv2d-23       [-1, 256, 56, 56]      16,384
          BatchNorm2d-24       [-1, 256, 56, 56]         512
                 ReLU-25       [-1, 256, 56, 56]           0
           Bottleneck-26       [-1, 256, 56, 56]           0
               Conv2d-27        [-1, 64, 56, 56]      16,384
          BatchNorm2d-28        [-1, 64, 56, 56]         128
                 ReLU-29        [-1, 64, 56, 56]           0
               Conv2d-30        [-1, 64, 56, 56]      36,864
          BatchNorm2d-31        [-1, 64, 56, 56]         128
                 ReLU-32        [-1, 64, 56, 56]           0
               Conv2d-33       [-1, 256, 56, 56]      16,384
          BatchNorm2d-34       [-1, 256, 56, 56]         512
                 ReLU-35       [-1, 256, 56, 56]           0
           Bottleneck-36       [-1, 256, 56, 56]           0
               Conv2d-37       [-1, 128, 56, 56]      32,768
          BatchNorm2d-38       [-1, 128, 56, 56]         256
                 ReLU-39       [-1, 128, 56, 56]           0
               Conv2d-40       [-1, 128, 28, 28]     147,456
          BatchNorm2d-41       [-1, 128, 28, 28]         256
                 ReLU-42       [-1, 128, 28, 28]           0
               Conv2d-43       [-1, 512, 28, 28]      65,536
          BatchNorm2d-44       [-1, 512, 28, 28]       1,024
               Conv2d-45       [-1, 512, 28, 28]     131,072
          BatchNorm2d-46       [-1, 512, 28, 28]       1,024
                 ReLU-47       [-1, 512, 28, 28]           0
           Bottleneck-48       [-1, 512, 28, 28]           0
               Conv2d-49       [-1, 128, 28, 28]      65,536
          BatchNorm2d-50       [-1, 128, 28, 28]         256
                 ReLU-51       [-1, 128, 28, 28]           0
               Conv2d-52       [-1, 128, 28, 28]     147,456
          BatchNorm2d-53       [-1, 128, 28, 28]         256
                 ReLU-54       [-1, 128, 28, 28]           0
               Conv2d-55       [-1, 512, 28, 28]      65,536
          BatchNorm2d-56       [-1, 512, 28, 28]       1,024
                 ReLU-57       [-1, 512, 28, 28]           0
           Bottleneck-58       [-1, 512, 28, 28]           0
               Conv2d-59       [-1, 128, 28, 28]      65,536
          BatchNorm2d-60       [-1, 128, 28, 28]         256
                 ReLU-61       [-1, 128, 28, 28]           0
               Conv2d-62       [-1, 128, 28, 28]     147,456
          BatchNorm2d-63       [-1, 128, 28, 28]         256
                 ReLU-64       [-1, 128, 28, 28]           0
               Conv2d-65       [-1, 512, 28, 28]      65,536
          BatchNorm2d-66       [-1, 512, 28, 28]       1,024
                 ReLU-67       [-1, 512, 28, 28]           0
           Bottleneck-68       [-1, 512, 28, 28]           0
               Conv2d-69       [-1, 128, 28, 28]      65,536
          BatchNorm2d-70       [-1, 128, 28, 28]         256
                 ReLU-71       [-1, 128, 28, 28]           0
               Conv2d-72       [-1, 128, 28, 28]     147,456
          BatchNorm2d-73       [-1, 128, 28, 28]         256
                 ReLU-74       [-1, 128, 28, 28]           0
               Conv2d-75       [-1, 512, 28, 28]      65,536
          BatchNorm2d-76       [-1, 512, 28, 28]       1,024
                 ReLU-77       [-1, 512, 28, 28]           0
           Bottleneck-78       [-1, 512, 28, 28]           0
               Conv2d-79       [-1, 256, 28, 28]     131,072
          BatchNorm2d-80       [-1, 256, 28, 28]         512
                 ReLU-81       [-1, 256, 28, 28]           0
               Conv2d-82       [-1, 256, 14, 14]     589,824
          BatchNorm2d-83       [-1, 256, 14, 14]         512
                 ReLU-84       [-1, 256, 14, 14]           0
               Conv2d-85      [-1, 1024, 14, 14]     262,144
          BatchNorm2d-86      [-1, 1024, 14, 14]       2,048
               Conv2d-87      [-1, 1024, 14, 14]     524,288
          BatchNorm2d-88      [-1, 1024, 14, 14]       2,048
                 ReLU-89      [-1, 1024, 14, 14]           0
           Bottleneck-90      [-1, 1024, 14, 14]           0
               Conv2d-91       [-1, 256, 14, 14]     262,144
          BatchNorm2d-92       [-1, 256, 14, 14]         512
                 ReLU-93       [-1, 256, 14, 14]           0
               Conv2d-94       [-1, 256, 14, 14]     589,824
          BatchNorm2d-95       [-1, 256, 14, 14]         512
                 ReLU-96       [-1, 256, 14, 14]           0
               Conv2d-97      [-1, 1024, 14, 14]     262,144
          BatchNorm2d-98      [-1, 1024, 14, 14]       2,048
                 ReLU-99      [-1, 1024, 14, 14]           0
          Bottleneck-100      [-1, 1024, 14, 14]           0
              Conv2d-101       [-1, 256, 14, 14]     262,144
         BatchNorm2d-102       [-1, 256, 14, 14]         512
                ReLU-103       [-1, 256, 14, 14]           0
              Conv2d-104       [-1, 256, 14, 14]     589,824
         BatchNorm2d-105       [-1, 256, 14, 14]         512
                ReLU-106       [-1, 256, 14, 14]           0
              Conv2d-107      [-1, 1024, 14, 14]     262,144
         BatchNorm2d-108      [-1, 1024, 14, 14]       2,048
                ReLU-109      [-1, 1024, 14, 14]           0
          Bottleneck-110      [-1, 1024, 14, 14]           0
              Conv2d-111       [-1, 256, 14, 14]     262,144
         BatchNorm2d-112       [-1, 256, 14, 14]         512
                ReLU-113       [-1, 256, 14, 14]           0
              Conv2d-114       [-1, 256, 14, 14]     589,824
         BatchNorm2d-115       [-1, 256, 14, 14]         512
                ReLU-116       [-1, 256, 14, 14]           0
              Conv2d-117      [-1, 1024, 14, 14]     262,144
         BatchNorm2d-118      [-1, 1024, 14, 14]       2,048
                ReLU-119      [-1, 1024, 14, 14]           0
          Bottleneck-120      [-1, 1024, 14, 14]           0
              Conv2d-121       [-1, 256, 14, 14]     262,144
         BatchNorm2d-122       [-1, 256, 14, 14]         512
                ReLU-123       [-1, 256, 14, 14]           0
              Conv2d-124       [-1, 256, 14, 14]     589,824
         BatchNorm2d-125       [-1, 256, 14, 14]         512
                ReLU-126       [-1, 256, 14, 14]           0
              Conv2d-127      [-1, 1024, 14, 14]     262,144
         BatchNorm2d-128      [-1, 1024, 14, 14]       2,048
                ReLU-129      [-1, 1024, 14, 14]           0
          Bottleneck-130      [-1, 1024, 14, 14]           0
              Conv2d-131       [-1, 256, 14, 14]     262,144
         BatchNorm2d-132       [-1, 256, 14, 14]         512
                ReLU-133       [-1, 256, 14, 14]           0
              Conv2d-134       [-1, 256, 14, 14]     589,824
         BatchNorm2d-135       [-1, 256, 14, 14]         512
                ReLU-136       [-1, 256, 14, 14]           0
              Conv2d-137      [-1, 1024, 14, 14]     262,144
         BatchNorm2d-138      [-1, 1024, 14, 14]       2,048
                ReLU-139      [-1, 1024, 14, 14]           0
          Bottleneck-140      [-1, 1024, 14, 14]           0
              Conv2d-141       [-1, 512, 14, 14]     524,288
         BatchNorm2d-142       [-1, 512, 14, 14]       1,024
                ReLU-143       [-1, 512, 14, 14]           0
              Conv2d-144         [-1, 512, 7, 7]   2,359,296
         BatchNorm2d-145         [-1, 512, 7, 7]       1,024
                ReLU-146         [-1, 512, 7, 7]           0
              Conv2d-147        [-1, 2048, 7, 7]   1,048,576
         BatchNorm2d-148        [-1, 2048, 7, 7]       4,096
              Conv2d-149        [-1, 2048, 7, 7]   2,097,152
         BatchNorm2d-150        [-1, 2048, 7, 7]       4,096
                ReLU-151        [-1, 2048, 7, 7]           0
          Bottleneck-152        [-1, 2048, 7, 7]           0
              Conv2d-153         [-1, 512, 7, 7]   1,048,576
         BatchNorm2d-154         [-1, 512, 7, 7]       1,024
                ReLU-155         [-1, 512, 7, 7]           0
              Conv2d-156         [-1, 512, 7, 7]   2,359,296
         BatchNorm2d-157         [-1, 512, 7, 7]       1,024
                ReLU-158         [-1, 512, 7, 7]           0
              Conv2d-159        [-1, 2048, 7, 7]   1,048,576
         BatchNorm2d-160        [-1, 2048, 7, 7]       4,096
                ReLU-161        [-1, 2048, 7, 7]           0
          Bottleneck-162        [-1, 2048, 7, 7]           0
              Conv2d-163         [-1, 512, 7, 7]   1,048,576
         BatchNorm2d-164         [-1, 512, 7, 7]       1,024
                ReLU-165         [-1, 512, 7, 7]           0
              Conv2d-166         [-1, 512, 7, 7]   2,359,296
         BatchNorm2d-167         [-1, 512, 7, 7]       1,024
                ReLU-168         [-1, 512, 7, 7]           0
              Conv2d-169        [-1, 2048, 7, 7]   1,048,576
         BatchNorm2d-170        [-1, 2048, 7, 7]       4,096
                ReLU-171        [-1, 2048, 7, 7]           0
          Bottleneck-172        [-1, 2048, 7, 7]           0
   AdaptiveMaxPool2d-173        [-1, 2048, 1, 1]           0
   AdaptiveAvgPool2d-174        [-1, 2048, 1, 1]           0
AdaptiveConcatPool2d-175        [-1, 4096, 1, 1]           0
             Flatten-176              [-1, 4096]           0
         BatchNorm1d-177              [-1, 4096]       8,192
             Dropout-178              [-1, 4096]           0
              Linear-179               [-1, 512]   2,097,664
                ReLU-180               [-1, 512]           0
         BatchNorm1d-181               [-1, 512]       1,024
             Dropout-182               [-1, 512]           0
              Linear-183              [-1, 6000]   3,078,000
============================================================
Total params: 28,692,912
Trainable params: 28,692,912
Non-trainable params: 0
------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 286.75
Params size (MB): 109.45
Estimated Total Size (MB): 396.78
------------------------------------------------------------
'''
class AdaptiveConcatPool2d(nn.Module):
    """
    Layer that concatenates `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
    Source: fastai. This code was taken from the fastai library at
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
    """
    def __init__(self, sz=None):
        "Output will be 2*sz, or 2 if sz is None"
        super().__init__()
        self.output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x):
        return torch.cat([self.mp(x), self.ap(x)], 1)
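
# Illustration (added for this write-up; the `_demo_` helper name is
# hypothetical): a minimal sketch of the shape contract of
# AdaptiveConcatPool2d.
def _demo_adaptive_concat_pool():
    # Channels double because the max- and avg-pooled maps are concatenated,
    # e.g. (2, 2048, 7, 7) -> (2, 4096, 1, 1), matching rows 173-175 of the
    # summary above.
    x = torch.randn(2, 2048, 7, 7)
    assert AdaptiveConcatPool2d()(x).shape == (2, 4096, 1, 1)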

class Flatten(nn.Module):
    """
    Flatten `x` to a single dimension. Adapted from fastai's Flatten() layer at
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L25
    """
    def forward(self, x):
        return x.view(x.size(0), -1)

def bn_drop_lin(n_in: int, n_out: int, bn: bool = True, p: float = 0., actn=None):
    """
    Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`, `n_out`)
    layers, followed by `actn`. Adapted from fastai at
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L44
    """
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0:
        layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None:
        layers.append(actn)
    return layers
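
# Illustration (added; hypothetical helper name): bn_drop_lin returns a plain
# list, so blocks can be spliced into an nn.Sequential with `*`.
def _demo_bn_drop_lin():
    # BatchNorm1d -> Dropout(0.25) -> Linear(4096, 512) -> ReLU
    block = nn.Sequential(*bn_drop_lin(4096, 512, bn=True, p=0.25, actn=nn.ReLU(inplace=True)))
    out = block(torch.randn(2, 4096))
    assert out.shape == (2, 512)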

def create_head(top_n_tags, nf, ps=0.5):
    """Build the classification head: pool, flatten, then two bn/dropout/linear blocks."""
    nc = top_n_tags
    lin_ftrs = [nf, 512, nc]
    p1 = 0.25  # dropout for the second-to-last layer
    p2 = 0.5   # dropout for the last layer
    pool = AdaptiveConcatPool2d()
    layers = [pool, Flatten()]
    layers += [
        *bn_drop_lin(lin_ftrs[0], lin_ftrs[1], True, p1, nn.ReLU(inplace=True)),
        *bn_drop_lin(lin_ftrs[1], lin_ftrs[2], True, p2),
    ]
    return nn.Sequential(*layers)
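
# Illustration (added; hypothetical helper name): the head maps pooled body
# features to tag logits, mirroring rows 173-183 of the summary above.
def _demo_create_head():
    head = create_head(top_n_tags=6000, nf=4096)  # nf = 2 * 2048 (concat pooling)
    out = head(torch.randn(2, 2048, 7, 7))
    assert out.shape == (2, 6000)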

def _resnet(base_arch, top_n, **kwargs):
    cut = -2  # drop the final average-pool and FC layers
    s = base_arch(pretrained=False, **kwargs)
    body = nn.Sequential(*list(s.children())[:cut])
    if base_arch in [models.resnet18, models.resnet34]:
        num_features_model = 512
    elif base_arch in [models.resnet50, models.resnet101]:
        num_features_model = 2048
    else:
        raise ValueError("Unsupported base architecture")
    nf = num_features_model * 2
    nc = top_n
    # head = create_head(nc, nf)
    model = body  # nn.Sequential(body, head)
    return model
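
# Illustration (added; hypothetical helper name): with cut=-2 the body keeps
# [conv1, bn1, relu, maxpool, layer1..layer4] and ends at the 7x7 feature map.
def _demo_resnet_body_shape():
    body = _resnet(models.resnet50, top_n=6000)
    body.eval()  # eval mode so BatchNorm accepts a batch of one
    with torch.no_grad():
        feats = body(torch.randn(1, 3, 224, 224))
    assert feats.shape == (1, 2048, 7, 7)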

def resnet50(pretrained=True, progress=True, top_n=6000, **kwargs):
    r"""
    ResNet-50 model trained on the full Danbooru2018 dataset's top 6,000 tags.

    Args:
        pretrained (bool): load pretrained weights into the model.
        progress (bool): show a progress bar while downloading the weights.
        top_n (int): load the model for predicting the top `n` tags;
            currently only top_n=6000 is supported.
    """
    # Take the ResNet body without the head (we don't need the final FC layers)
    model = _resnet(models.resnet50, top_n, **kwargs)
    if pretrained:
        if top_n == 6000:
            state = torch.hub.load_state_dict_from_url(
                "https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
                progress=progress)
            # The checkpoint was saved from nn.Sequential(body, head): strip the
            # "0." prefix from the body's keys and drop the head's "1." keys.
            old_keys = [key for key in state]
            for old_key in old_keys:
                if old_key[0] == '0':
                    new_key = old_key[2:]
                    state[new_key] = state[old_key]
                    del state[old_key]
                elif old_key[0] == '1':
                    del state[old_key]
            model.load_state_dict(state)
        else:
            raise ValueError("Sorry, the resnet50 model only supports the top-6000 tags at the moment")
    return model
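
# Illustration (added; toy dict and hypothetical helper name): the checkpoint
# remapping above, run on a miniature state dict. Body keys lose their "0."
# prefix; head keys under "1." are dropped because the head is not built here.
def _demo_key_remap():
    state = {"0.conv1.weight": 1, "0.bn1.bias": 2, "1.2.weight": 3}
    for old_key in list(state):
        if old_key.startswith("0."):
            state[old_key[2:]] = state.pop(old_key)
        elif old_key.startswith("1."):
            del state[old_key]
    assert state == {"conv1.weight": 1, "bn1.bias": 2}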

class resnet50_Extractor(nn.Module):
    """ResNet-50 network for feature extraction via forward hooks."""

    def get_activation(self, name):
        def hook(model, input, output):
            # Note: detach() means gradients do not flow through the stored
            # activations, so this extractor suits evaluation and analysis.
            self.activation[name] = output.detach()
        return hook

    def __init__(self,
                 model,
                 layer_labels,
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False):
        super(resnet50_Extractor, self).__init__()
        self.model = model
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm
        self.layer_labels = layer_labels
        self.activation = {}

        # Register a hook on every requested layer. A label such as "4_2_conv3"
        # addresses model[4][2].conv3; a label without '_' is a direct attribute.
        for layer_label in layer_labels:
            elements = layer_label.split('_')
            if len(elements) == 1:
                getattr(self.model, elements[0]).register_forward_hook(self.get_activation(layer_label))
            else:
                body_layer = self.model
                for element in elements[:-1]:
                    # Walk down the indices until the last element
                    assert element.isdigit()
                    body_layer = body_layer[int(element)]
                getattr(body_layer, elements[-1]).register_forward_hook(self.get_activation(layer_label))

        # Set to evaluation mode and freeze the parameters
        if not requires_grad:
            self.model.eval()
            for param in self.parameters():
                param.requires_grad = False

        if self.use_input_norm:
            # ImageNet mean for images in the range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # ImageNet std for images in the range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Mapping from layer label to the hooked activation.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        # Run the full model; its own output is discarded, since we only care
        # about the activations captured by the hooks
        self.model(x)
        # Collect the layers we need
        store = {}
        for layer_label in self.layer_labels:
            store[layer_label] = self.activation[layer_label]
        return store
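
# Illustration (added; hypothetical helper name): a layer label such as
# "4_2_conv3" addresses body[4][2].conv3, i.e. the conv3 of layer1's third
# Bottleneck, whose activation is (n, 256, 56, 56) for a 224x224 input.
def _demo_extractor_labels():
    body = _resnet(models.resnet50, top_n=6000)  # untrained body; shapes only
    extractor = resnet50_Extractor(body, ["4_2_conv3"])
    feats = extractor(torch.randn(1, 3, 224, 224))
    assert feats["4_2_conv3"].shape == (1, 256, 56, 56)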

class Anime_PerceptualLoss(nn.Module):
    """Anime perceptual loss.

    Args:
        layer_weights (dict): The weight for each extracted layer of the
            ResNet-50 features. For example, {'4_2_conv3': 20.0} means the
            activation of body[4][2].conv3 is extracted and weighted by 20.0
            when calculating the loss.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and multiplied by this weight.
            Default: 1.0.
        criterion (str): Criterion used for the perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 perceptual_weight=1.0,
                 criterion='l1'):
        super(Anime_PerceptualLoss, self).__init__()
        model = resnet50()
        self.perceptual_weight = perceptual_weight
        self.layer_weights = layer_weights
        self.layer_labels = layer_weights.keys()
        self.resnet50 = resnet50_Extractor(model, self.layer_labels).cuda()
        if criterion == 'l1':
            self.criterion = torch.nn.L1Loss()
        else:
            raise NotImplementedError("This criterion is not supported in the perceptual loss")

    def forward(self, gen, gt):
        """Forward function.

        Args:
            gen (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: the weighted perceptual loss, preceded by a list of
            per-layer losses when they are collected below.
        """
        # Extract ResNet-50 features
        gen_features = self.resnet50(gen)
        gt_features = self.resnet50(gt.detach())

        temp_store = []
        # Calculate the perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in gen_features.keys():
                raw_comparison = self.criterion(gen_features[k], gt_features[k])
                percep_loss += raw_comparison * self.layer_weights[k]
                # Per-layer values used by the debug harness in __main__ below
                temp_store.append(raw_comparison.item())
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # The first return value is for debugging purposes
        if len(temp_store) != 0:
            return temp_store, percep_loss
        else:
            return percep_loss

if __name__ == "__main__":
    import collections

    # Per-layer weights; the labels follow the "index_index_attr" convention above
    loss = Anime_PerceptualLoss({"0": 0.5, "4_2_conv3": 20, "5_3_conv3": 30, "6_5_conv3": 1, "7_2_conv3": 1}).cuda()
    store = collections.defaultdict(list)
    for img_name in sorted(os.listdir('datasets/train_gen/')):
        # cv2 images are HWC uint8; ToTensor gives CHW in [0, 1], then add the batch dimension
        gen = transforms.ToTensor()(cv2.imread('datasets/train_gen/' + img_name)).unsqueeze(0).cuda()
        gt = transforms.ToTensor()(cv2.imread('datasets/train_hr_anime_usm/' + img_name)).unsqueeze(0).cuda()
        temp_store, _ = loss(gen, gt)
        for idx in range(len(temp_store)):
            store[idx].append(temp_store[idx])
    for idx in range(len(store)):
        print("Average layer" + str(idx) + " has loss " + str(sum(store[idx]) / len(store[idx])))
    # model = loss.resnet50
    # pytorch_total_params = sum(p.numel() for p in model.parameters())
    # print(f"The perceptual ResNet-50 has {pytorch_total_params // 1000000}M params")