from collections import OrderedDict
from itertools import chain
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models


def get_network(net_type: str = 'vgg'):
    """Build the feature-extraction backbone selected by ``net_type``.

    Arguments:
        net_type (str): one of 'alex' | 'squeeze' | 'vgg'. Default: 'vgg'.

    Returns:
        An instance of ``AlexNet``, ``SqueezeNet`` or ``VGG16``.

    Raises:
        NotImplementedError: if ``net_type`` is not one of the three names.
    """
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


def normalize_activation(x, eps=1e-10):
    """Scale ``x`` to (approximately) unit L2 norm along dimension 1.

    ``eps`` is added to the norm so an all-zero vector does not cause a
    division by zero.
    """
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


class BaseNet(nn.Module):
    """Common behaviour for the backbones defined in this file.

    Subclasses are expected to set ``self.layers`` (the module sequence that
    is run layer by layer), ``self.target_layers`` (1-based positions in
    ``self.layers`` whose outputs are collected) and ``self.n_channels_list``.
    """

    def __init__(self):
        super().__init__()

        # Per-channel shift and scale, shaped (1, 3, 1, 1) so they broadcast
        # over a 4-D input. Registered as buffers so they follow the module
        # across ``.to(device)`` calls.
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad`` to ``state`` on every parameter and buffer."""
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        """Return ``(x - mean) / std`` using the registered buffers."""
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        """Run ``x`` through ``self.layers`` and collect target activations.

        Returns:
            A list with one normalized activation per entry of
            ``self.target_layers``, in the order the layers are executed.
        """
        x = self.z_score(x)

        output = []
        # Positions are counted from 1 to match ``target_layers``.
        for i, layer in enumerate(self.layers, 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            # Nothing after the last target layer is needed.
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    """Backbone built from the ``features`` of torchvision's SqueezeNet 1.1."""

    def __init__(self):
        super().__init__()

        # The positional ``True`` is torchvision's legacy "pretrained" flag.
        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    """Backbone built from the ``features`` of torchvision's AlexNet."""

    def __init__(self):
        super().__init__()

        # The positional ``True`` is torchvision's legacy "pretrained" flag.
        self.layers = models.alexnet(True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    """Backbone built from the ``features`` of torchvision's VGG-16."""

    def __init__(self):
        super().__init__()

        # The positional ``True`` is torchvision's legacy "pretrained" flag.
        self.layers = models.vgg16(True).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)


class LinLayers(nn.ModuleList):
    """One bias-free 1x1 convolution (to a single channel) per feature map.

    Each entry is ``Sequential(Identity, Conv2d)`` so the convolution sits at
    index 1, which is the key layout ``get_state_dict`` produces.
    All parameters are frozen.
    """

    def __init__(self, n_channels_list: Sequence[int]):
        super().__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    """Download the linear-layer weights and rename keys for ``LinLayers``.

    Arguments:
        net_type (str): backbone name used to build the weight file name.
            Default: 'alex'.
        version (str): weight version used to build the URL. Default: '0.1'.

    Returns:
        An ``OrderedDict`` with 'lin' and 'model.' stripped from every key.
    """
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download (force CPU placement only when CUDA is unavailable)
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network type to compare the features:
                        'alex' | 'squeeze' | 'vgg'. Default: 'vgg'.
        version (str): the version of LPIPS. Default: 0.1.
        device (str or torch.device, optional): where to place the backbone
                        and the linear layers. Default: 'cuda' when CUDA is
                        available, otherwise 'cpu'.
    """

    def __init__(self, net_type: str = 'vgg', version: str = '0.1',
                 device: Optional[Union[str, torch.device]] = None):

        assert version in ['0.1'], 'v0.1 is only supported now'

        super().__init__()

        # Fall back to CPU rather than failing on hosts without CUDA.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # pretrained network
        self.net = get_network(net_type).to(device)

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list).to(device)
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return the LPIPS distance between ``x`` and ``y``.

        The squared feature differences of every target layer are passed
        through the matching linear layer and averaged over dims (2, 3);
        the results are summed over all layers and samples, then divided by
        ``x.shape[0]``.
        """
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [lin(d).mean((2, 3), True) for d, lin in zip(diff, self.lin)]

        return torch.sum(torch.cat(res, 0)) / x.shape[0]