Vincentqyw
fix: roma
358ab8f
raw
history blame
6.25 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm
class ResNet18(nn.Module):
    """Wrap ``tvm.resnet18`` and expose a pyramid of intermediate activations.

    Args:
        pretrained: forwarded unchanged to ``tvm.resnet18(pretrained=...)``.
    """

    def __init__(self, pretrained=False) -> None:
        super().__init__()
        self.net = tvm.resnet18(pretrained=pretrained)

    def forward(self, x):
        """Run the wrapped network and collect intermediate feature maps.

        Returns:
            dict with keys 32, 16, 8, 4, 2, 1 mapping to the outputs of
            ``layer4``, ``layer3``, ``layer2``, ``layer1``, the post-ReLU stem
            (conv1 -> bn1 -> relu), and the unmodified input ``x``.
            The keys are presumably the downsampling factor relative to ``x``
            (depends on torchvision's ResNet strides) — TODO confirm.
        """
        # Use a local alias rather than rebinding ``self`` to the inner network.
        net = self.net
        x2 = net.relu(net.bn1(net.conv1(x)))
        x4 = net.layer1(net.maxpool(x2))
        x8 = net.layer2(x4)
        x16 = net.layer3(x8)
        x32 = net.layer4(x16)
        return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x}

    def train(self, mode=True):
        """Set train/eval mode while keeping every ``nn.BatchNorm2d`` in eval mode.

        BatchNorm layers are always put back into eval mode. They therefore use,
        and do not update, their running statistics.

        Returns:
            ``self``, matching ``nn.Module.train``. This matters because
            ``nn.Module.eval()`` returns ``self.train(False)``.
        """
        super().train(mode)
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
        return self
class ResNet50(nn.Module):
    """Wrap ``tvm.resnet50`` and expose a pyramid of intermediate activations.

    Args:
        pretrained: forwarded to ``tvm.resnet50(pretrained=...)`` when
            ``weights`` is None.
        high_res: stored on the instance; ``forward`` in this class does not
            read it.
        weights: if not None, forwarded to ``tvm.resnet50(weights=...)``.
            It takes precedence over ``pretrained``.
        dilation: forwarded as ``replace_stride_with_dilation``. Defaults to
            ``[False, False, False]``; a fresh list is built per call.
        freeze_bn: if True, ``train()`` keeps every ``nn.BatchNorm2d`` in
            eval mode.
        anti_aliased: not implemented.

    Raises:
        NotImplementedError: if ``anti_aliased`` is True.
    """

    def __init__(
        self,
        pretrained=False,
        high_res=False,
        weights=None,
        dilation=None,
        freeze_bn=True,
        anti_aliased=False,
    ) -> None:
        super().__init__()
        if anti_aliased:
            # Previously this branch was a silent no-op that left ``self.net``
            # undefined, so the failure only surfaced later in ``forward``.
            raise NotImplementedError(
                "anti_aliased=True is not supported by ResNet50"
            )
        if dilation is None:
            dilation = [False, False, False]
        if weights is not None:
            self.net = tvm.resnet50(
                weights=weights, replace_stride_with_dilation=dilation
            )
        else:
            self.net = tvm.resnet50(
                pretrained=pretrained, replace_stride_with_dilation=dilation
            )
        self.high_res = high_res
        self.freeze_bn = freeze_bn

    def forward(self, x):
        """Run the wrapped network and collect intermediate feature maps.

        Returns:
            dict with keys 1, 2, 4, 8, 16, 32 (in that insertion order) mapping
            to the unmodified input ``x``, the post-ReLU stem
            (conv1 -> bn1 -> relu), and the outputs of ``layer1``..``layer4``.
            The keys look like nominal downsampling factors; with a non-default
            ``dilation`` the actual resolution may differ — TODO confirm.
        """
        net = self.net
        feats = {1: x}
        x = net.relu(net.bn1(net.conv1(x)))
        feats[2] = x
        x = net.layer1(net.maxpool(x))
        feats[4] = x
        x = net.layer2(x)
        feats[8] = x
        x = net.layer3(x)
        feats[16] = x
        x = net.layer4(x)
        feats[32] = x
        return feats

    def train(self, mode=True):
        """Set train/eval mode; if ``freeze_bn``, keep BatchNorm2d layers in eval.

        Returns:
            ``self``, matching ``nn.Module.train``. This matters because
            ``nn.Module.eval()`` returns ``self.train(False)``.
        """
        super().train(mode)
        if self.freeze_bn:
            for module in self.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
        return self
class ResNet101(nn.Module):
    """Wrap ``tvm.resnet101`` and expose a pyramid of intermediate activations.

    Args:
        pretrained: forwarded to ``tvm.resnet101(pretrained=...)`` when
            ``weights`` is None.
        high_res: if True, ``forward`` first upsamples the input by
            ``scale_factor`` (1.5) with bicubic interpolation. It then bilinearly
            rescales every returned feature map by ``1 / scale_factor``.
        weights: if not None, forwarded to ``tvm.resnet101(weights=...)``.
            It takes precedence over ``pretrained``.
    """

    def __init__(self, pretrained=False, high_res=False, weights=None) -> None:
        super().__init__()
        if weights is not None:
            self.net = tvm.resnet101(weights=weights)
        else:
            self.net = tvm.resnet101(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5

    def _restore_scale(self, x):
        """Undo the ``high_res`` input upsampling on a feature map (identity otherwise)."""
        if not self.high_res:
            return x
        return F.interpolate(
            x, scale_factor=1 / self.scale_factor, align_corners=False, mode="bilinear"
        )

    def forward(self, x):
        """Run the wrapped network and collect intermediate feature maps.

        Returns:
            dict with keys 1, 2, 4, 8, 16, 32 (in that insertion order) mapping
            to the unmodified input ``x``, the post-ReLU stem
            (conv1 -> bn1 -> relu), and the outputs of ``layer1``..``layer4``.
            Every entry except key 1 goes through ``_restore_scale``.
            The keys are presumably nominal downsampling factors — TODO confirm.
        """
        net = self.net
        feats = {1: x}
        if self.high_res:
            x = F.interpolate(
                x, scale_factor=self.scale_factor, align_corners=False, mode="bicubic"
            )
        x = net.relu(net.bn1(net.conv1(x)))
        feats[2] = self._restore_scale(x)
        x = net.layer1(net.maxpool(x))
        feats[4] = self._restore_scale(x)
        for key, layer in ((8, net.layer2), (16, net.layer3), (32, net.layer4)):
            x = layer(x)
            feats[key] = self._restore_scale(x)
        return feats

    def train(self, mode=True):
        """Set train/eval mode while keeping every ``nn.BatchNorm2d`` in eval mode.

        Returns:
            ``self``, matching ``nn.Module.train``. This matters because
            ``nn.Module.eval()`` returns ``self.train(False)``.
        """
        super().train(mode)
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
        return self
class WideResNet50(nn.Module):
    """Wrap ``tvm.wide_resnet50_2`` and expose a pyramid of intermediate activations.

    Args:
        pretrained: forwarded to ``tvm.wide_resnet50_2(pretrained=...)`` when
            ``weights`` is None.
        high_res: if True, ``forward`` first upsamples the input by
            ``scale_factor`` (1.5) with bicubic interpolation. It then bilinearly
            rescales every returned feature map by ``1 / scale_factor``.
        weights: if not None, forwarded to ``tvm.wide_resnet50_2(weights=...)``.
            It takes precedence over ``pretrained``.
    """

    def __init__(self, pretrained=False, high_res=False, weights=None) -> None:
        super().__init__()
        if weights is not None:
            self.net = tvm.wide_resnet50_2(weights=weights)
        else:
            self.net = tvm.wide_resnet50_2(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5

    def _restore_scale(self, x):
        """Undo the ``high_res`` input upsampling on a feature map (identity otherwise)."""
        if not self.high_res:
            return x
        return F.interpolate(
            x, scale_factor=1 / self.scale_factor, align_corners=False, mode="bilinear"
        )

    def forward(self, x):
        """Run the wrapped network and collect intermediate feature maps.

        Returns:
            dict with keys 1, 2, 4, 8, 16, 32 (in that insertion order) mapping
            to the unmodified input ``x``, the post-ReLU stem
            (conv1 -> bn1 -> relu), and the outputs of ``layer1``..``layer4``.
            Every entry except key 1 goes through ``_restore_scale``.
            The keys are presumably nominal downsampling factors — TODO confirm.
        """
        net = self.net
        feats = {1: x}
        if self.high_res:
            x = F.interpolate(
                x, scale_factor=self.scale_factor, align_corners=False, mode="bicubic"
            )
        x = net.relu(net.bn1(net.conv1(x)))
        feats[2] = self._restore_scale(x)
        x = net.layer1(net.maxpool(x))
        feats[4] = self._restore_scale(x)
        for key, layer in ((8, net.layer2), (16, net.layer3), (32, net.layer4)):
            x = layer(x)
            feats[key] = self._restore_scale(x)
        return feats

    def train(self, mode=True):
        """Set train/eval mode while keeping every ``nn.BatchNorm2d`` in eval mode.

        Returns:
            ``self``, matching ``nn.Module.train``. This matters because
            ``nn.Module.eval()`` returns ``self.train(False)``.
        """
        super().train(mode)
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
        return self