Vincentqyw
fix: roma
358ab8f
raw
history blame
5.04 kB
from typing import Optional, Union
import torch
from torch import device
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm
import gc
class ResNet50(nn.Module):
def __init__(
self,
pretrained=False,
high_res=False,
weights=None,
dilation=None,
freeze_bn=True,
anti_aliased=False,
early_exit=False,
amp=False,
) -> None:
super().__init__()
if dilation is None:
dilation = [False, False, False]
if anti_aliased:
pass
else:
if weights is not None:
self.net = tvm.resnet50(
weights=weights, replace_stride_with_dilation=dilation
)
else:
self.net = tvm.resnet50(
pretrained=pretrained, replace_stride_with_dilation=dilation
)
self.high_res = high_res
self.freeze_bn = freeze_bn
self.early_exit = early_exit
self.amp = amp
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
self.amp_dtype = torch.bfloat16
else:
self.amp_dtype = torch.float16
def forward(self, x, **kwargs):
with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
net = self.net
feats = {1: x}
x = net.conv1(x)
x = net.bn1(x)
x = net.relu(x)
feats[2] = x
x = net.maxpool(x)
x = net.layer1(x)
feats[4] = x
x = net.layer2(x)
feats[8] = x
if self.early_exit:
return feats
x = net.layer3(x)
feats[16] = x
x = net.layer4(x)
feats[32] = x
return feats
def train(self, mode=True):
super().train(mode)
if self.freeze_bn:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
pass
class VGG19(nn.Module):
def __init__(self, pretrained=False, amp=False) -> None:
super().__init__()
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
self.amp = amp
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
self.amp_dtype = torch.bfloat16
else:
self.amp_dtype = torch.float16
def forward(self, x, **kwargs):
with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
feats = {}
scale = 1
for layer in self.layers:
if isinstance(layer, nn.MaxPool2d):
feats[scale] = x
scale = scale * 2
x = layer(x)
return feats
class CNNandDinov2(nn.Module):
def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None):
super().__init__()
if dinov2_weights is None:
dinov2_weights = torch.hub.load_state_dict_from_url(
"https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
map_location="cpu",
)
from .transformer import vit_large
vit_kwargs = dict(
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp",
block_chunks=0,
)
dinov2_vitl14 = vit_large(**vit_kwargs).eval()
dinov2_vitl14.load_state_dict(dinov2_weights)
cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
if not use_vgg:
self.cnn = ResNet50(**cnn_kwargs)
else:
self.cnn = VGG19(**cnn_kwargs)
self.amp = amp
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
self.amp_dtype = torch.bfloat16
else:
self.amp_dtype = torch.float16
if self.amp:
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
def train(self, mode: bool = True):
return self.cnn.train(mode)
def forward(self, x, upsample=False):
B, C, H, W = x.shape
feature_pyramid = self.cnn(x)
if not upsample:
with torch.no_grad():
if self.dinov2_vitl14[0].device != x.device:
self.dinov2_vitl14[0] = (
self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
)
dinov2_features_16 = self.dinov2_vitl14[0].forward_features(
x.to(self.amp_dtype)
)
features_16 = (
dinov2_features_16["x_norm_patchtokens"]
.permute(0, 2, 1)
.reshape(B, 1024, H // 14, W // 14)
)
del dinov2_features_16
feature_pyramid[16] = features_16
return feature_pyramid