""" | |
Neighborhood Attention Transformer. | |
https://arxiv.org/abs/2204.07143 | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import torch
import torchvision
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class VGGPerceptualLoss(torch.nn.Module):
    """Extract multi-stage VGG-16 features for perceptual (appearance) losses.

    Four frozen slices of a pretrained VGG-16 feature stack are kept
    (boundaries at features[:4], [4:9], [9:16], [16:23]); ``forward`` returns
    the activations of the requested stages.

    Args:
        resize: If True, bilinearly resize inputs to 224x224 before encoding.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Load the pretrained network once and slice it into four stages
        # (the original constructed vgg16(pretrained=True) four times,
        # loading the same weights repeatedly).
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            features[:4].eval(),
            features[4:9].eval(),
            features[9:16].eval(),
            features[16:23].eval(),
        ]
        # Freeze: this is a fixed feature extractor, never trained.
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # ImageNet normalization constants, registered as buffers so they
        # follow the module across .to(device)/.half() calls.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, appearance_layers=(0, 1, 2, 3)):
        """Return VGG activations of ``input`` at the requested stages.

        Args:
            input: Image batch (N, C, H, W); a non-3-channel batch is tiled
                with ``repeat(1, 3, 1, 1)`` — this only produces 3 channels
                for single-channel input. Presumably values are in [0, 1];
                TODO confirm against callers.
            appearance_layers: Indices (0-3) of the stages to collect.
                (Tuple default instead of the original mutable list.)

        Returns:
            List of feature tensors, shallowest stage first.
        """
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            # Bug fix: the original also executed `target = target.repeat(...)`
            # here, but `target` is undefined in this scope (forward takes no
            # target argument) — it raised NameError for any grayscale input.
        input = (input - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
        x = input
        feats = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in appearance_layers:
                feats.append(x)
        return feats
class DINOv2(torch.nn.Module):
    """Frozen DINOv2 backbone used as a multi-layer spatial feature extractor.

    Args:
        resize: If True, bicubically resize inputs to ``size`` x ``size``.
        size: Side length used when ``resize`` is True.
        model_type: Hub identifier passed to ``torch.hub.load``.
    """

    def __init__(self, resize=True, size=224, model_type='dinov2_vitl14'):
        super(DINOv2, self).__init__()
        self.size = size
        self.resize = resize
        self.transform = torch.nn.functional.interpolate
        # Downloads the backbone on first use; requires network access.
        self.model = torch.hub.load('facebookresearch/dinov2', model_type)
        # ImageNet normalization constants as buffers (device/dtype-tracked).
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, appearance_layers=(1, 2)):
        """Return detached intermediate DINOv2 feature maps for ``input``.

        Args:
            input: Image batch (N, C, H, W); a non-3-channel batch is tiled
                with ``repeat(1, 3, 1, 1)`` (only valid for 1-channel input).
            appearance_layers: NOTE(review): accepted but never used — all
                ``n_blocks`` intermediate layers are returned regardless.
                Kept (as a tuple default) for interface compatibility.

        Returns:
            List of detached feature tensors, one per backbone block
            (``reshape=True`` gives spatial (N, C, H', W') maps).
        """
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            # Bug fix: the original also executed `target = target.repeat(...)`
            # here, but `target` is undefined in this scope — it raised
            # NameError for any grayscale input.
        if self.resize:
            input = self.transform(input, mode='bicubic', size=(self.size, self.size), align_corners=False)
        input = (input - self.mean) / self.std
        feats = self.model.get_intermediate_layers(input, self.model.n_blocks, reshape=True)
        # Detach: features are targets/conditioning, never backpropagated into.
        feats = [f.detach() for f in feats]
        return feats