from typing import Optional, Union import torch from torch import device import torch.nn as nn import torch.nn.functional as F import torchvision.models as tvm import gc class ResNet50(nn.Module): def __init__( self, pretrained=False, high_res=False, weights=None, dilation=None, freeze_bn=True, anti_aliased=False, early_exit=False, amp=False, ) -> None: super().__init__() if dilation is None: dilation = [False, False, False] if anti_aliased: pass else: if weights is not None: self.net = tvm.resnet50( weights=weights, replace_stride_with_dilation=dilation ) else: self.net = tvm.resnet50( pretrained=pretrained, replace_stride_with_dilation=dilation ) self.high_res = high_res self.freeze_bn = freeze_bn self.early_exit = early_exit self.amp = amp if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): self.amp_dtype = torch.bfloat16 else: self.amp_dtype = torch.float16 def forward(self, x, **kwargs): with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): net = self.net feats = {1: x} x = net.conv1(x) x = net.bn1(x) x = net.relu(x) feats[2] = x x = net.maxpool(x) x = net.layer1(x) feats[4] = x x = net.layer2(x) feats[8] = x if self.early_exit: return feats x = net.layer3(x) feats[16] = x x = net.layer4(x) feats[32] = x return feats def train(self, mode=True): super().train(mode) if self.freeze_bn: for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() pass class VGG19(nn.Module): def __init__(self, pretrained=False, amp=False) -> None: super().__init__() self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40]) self.amp = amp if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): self.amp_dtype = torch.bfloat16 else: self.amp_dtype = torch.float16 def forward(self, x, **kwargs): with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype): feats = {} scale = 1 for layer in self.layers: if isinstance(layer, nn.MaxPool2d): feats[scale] = x scale = scale * 2 x = layer(x) return feats class CNNandDinov2(nn.Module): def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None): super().__init__() if dinov2_weights is None: dinov2_weights = torch.hub.load_state_dict_from_url( "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu", ) from .transformer import vit_large vit_kwargs = dict( img_size=518, patch_size=14, init_values=1.0, ffn_layer="mlp", block_chunks=0, ) dinov2_vitl14 = vit_large(**vit_kwargs).eval() dinov2_vitl14.load_state_dict(dinov2_weights) cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {} if not use_vgg: self.cnn = ResNet50(**cnn_kwargs) else: self.cnn = VGG19(**cnn_kwargs) self.amp = amp if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): self.amp_dtype = torch.bfloat16 else: self.amp_dtype = torch.float16 if self.amp: dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype) self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP def train(self, mode: bool = True): return self.cnn.train(mode) def forward(self, x, upsample=False): B, C, H, W = x.shape feature_pyramid = self.cnn(x) if not upsample: with torch.no_grad(): if self.dinov2_vitl14[0].device != x.device: self.dinov2_vitl14[0] = ( self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype) ) dinov2_features_16 = self.dinov2_vitl14[0].forward_features( x.to(self.amp_dtype) ) features_16 = ( dinov2_features_16["x_norm_patchtokens"] .permute(0, 2, 1) .reshape(B, 1024, H // 14, W // 14) ) del dinov2_features_16 feature_pyramid[16] = features_16 return feature_pyramid