Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as tvm | |
class ResNet18(nn.Module):
    """ResNet-18 backbone that exposes the intermediate feature maps.

    Wraps ``tvm.resnet18`` and returns, from ``forward``, a dict holding the
    input and the activation after each stage.  Batch-norm layers are always
    kept in eval mode, even while the rest of the module is training.
    """

    def __init__(self, pretrained=False) -> None:
        """Build the wrapped torchvision network.

        Args:
            pretrained: Forwarded unchanged to ``tvm.resnet18``.
        """
        super().__init__()
        self.net = tvm.resnet18(pretrained=pretrained)

    def forward(self, x):
        """Run the backbone and collect the per-stage activations.

        Args:
            x: Input tensor fed to the first convolution.

        Returns:
            A dict with keys ``32, 16, 8, 4, 2, 1`` (in that insertion
            order).  Key 1 is the untouched input, key 2 the activation after
            conv1/bn1/relu, and keys 4/8/16/32 the outputs of layer1..layer4.
            The keys presumably denote the nominal downsampling factor of a
            standard torchvision ResNet -- confirm against the callers.
        """
        # Use a local alias instead of rebinding ``self``.
        net = self.net
        x1 = x
        x = net.conv1(x1)
        x = net.bn1(x)
        x2 = net.relu(x)
        x = net.maxpool(x2)
        x4 = net.layer1(x)
        x8 = net.layer2(x4)
        x16 = net.layer3(x8)
        x32 = net.layer4(x16)
        return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x1}

    def train(self, mode=True):
        """Set the training mode while keeping every BatchNorm2d in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``.  ``nn.Module.eval()`` returns the result of
            ``self.train(False)``, so an override that returns nothing makes
            both ``model.train()`` and ``model.eval()`` evaluate to ``None``.
        """
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Frozen statistics: never update running mean/var.
                m.eval()
        return self
class ResNet50(nn.Module):
    """ResNet-50 backbone that exposes the intermediate feature maps.

    Wraps ``tvm.resnet50`` with its ``fc`` head removed and returns, from
    ``forward``, a dict holding the input and the activation after each stage.
    """

    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False) -> None:
        """Build the wrapped torchvision network.

        Args:
            pretrained: Forwarded to ``tvm.resnet50`` when ``weights`` is None.
            high_res: Stored on the instance; not read by this class.
            weights: When not None, forwarded to ``tvm.resnet50`` instead of
                ``pretrained``.
            dilation: Forwarded as ``replace_stride_with_dilation``; defaults
                to ``[False, False, False]``.
            freeze_bn: When true, ``train`` keeps every BatchNorm2d in eval
                mode.
            anti_aliased: Not supported; must be false.

        Raises:
            NotImplementedError: If ``anti_aliased`` is true.  No network was
                ever built on that path, so construction used to fail later
                with an unrelated-looking ``AttributeError``.
        """
        super().__init__()
        if anti_aliased:
            raise NotImplementedError(
                "ResNet50(anti_aliased=True) is not implemented"
            )
        if dilation is None:
            dilation = [False, False, False]
        if weights is not None:
            self.net = tvm.resnet50(weights=weights, replace_stride_with_dilation=dilation)
        else:
            self.net = tvm.resnet50(pretrained=pretrained, replace_stride_with_dilation=dilation)
        # The classification head is never called by forward(); drop it.
        del self.net.fc
        self.high_res = high_res
        self.freeze_bn = freeze_bn

    def forward(self, x):
        """Run the backbone and collect the per-stage activations.

        Args:
            x: Input tensor fed to the first convolution.

        Returns:
            A dict with keys ``1, 2, 4, 8, 16, 32`` (in that insertion
            order).  Key 1 is the untouched input, key 2 the activation after
            conv1/bn1/relu, and keys 4/8/16/32 the outputs of layer1..layer4.
            The keys presumably denote nominal downsampling factors; with
            dilation enabled the real stride may differ -- confirm.
        """
        net = self.net
        feats = {1: x}
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = x
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = x
        x = net.layer2(x)
        feats[8] = x
        x = net.layer3(x)
        feats[16] = x
        x = net.layer4(x)
        feats[32] = x
        return feats

    def train(self, mode=True):
        """Set the training mode, optionally keeping BatchNorm2d in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``.  ``nn.Module.eval()`` returns the result of
            ``self.train(False)``, so an override that returns nothing makes
            both ``model.train()`` and ``model.eval()`` evaluate to ``None``.
        """
        super().train(mode)
        if self.freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    # Frozen statistics: never update running mean/var.
                    m.eval()
        return self
class ResNet101(nn.Module):
    """ResNet-101 backbone that exposes the intermediate feature maps.

    Wraps ``tvm.resnet101``.  With ``high_res`` enabled the input is upsampled
    by ``scale_factor`` before the network, and every returned feature map is
    scaled back by ``1 / scale_factor``.  Batch-norm layers are always kept in
    eval mode.
    """

    def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
        """Build the wrapped torchvision network.

        Args:
            pretrained: Forwarded to ``tvm.resnet101`` when ``weights`` is
                None.
            high_res: Enables the upsample-then-downsample path in
                ``forward``.
            weights: When not None, forwarded to ``tvm.resnet101`` instead of
                ``pretrained``.
        """
        super().__init__()
        if weights is not None:
            self.net = tvm.resnet101(weights=weights)
        else:
            self.net = tvm.resnet101(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5

    def _restore_scale(self, feat):
        """Scale ``feat`` by ``1 / scale_factor`` in high-res mode, else return it as is."""
        if not self.high_res:
            return feat
        return F.interpolate(feat, scale_factor=1/self.scale_factor, align_corners=False, mode="bilinear")

    def forward(self, x):
        """Run the backbone and collect the per-stage activations.

        Args:
            x: Input tensor; bicubic interpolation in high-res mode requires a
                4-D tensor.

        Returns:
            A dict with keys ``1, 2, 4, 8, 16, 32`` (in that insertion
            order).  Key 1 is the untouched input (before any upsampling),
            key 2 the activation after conv1/bn1/relu, and keys 4/8/16/32 the
            outputs of layer1..layer4, each passed through the high-res
            downscaling when enabled.
        """
        net = self.net
        feats = {1: x}
        if self.high_res:
            x = F.interpolate(x, scale_factor=self.scale_factor, align_corners=False, mode="bicubic")
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = self._restore_scale(x)
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = self._restore_scale(x)
        x = net.layer2(x)
        feats[8] = self._restore_scale(x)
        x = net.layer3(x)
        feats[16] = self._restore_scale(x)
        x = net.layer4(x)
        feats[32] = self._restore_scale(x)
        return feats

    def train(self, mode=True):
        """Set the training mode while keeping every BatchNorm2d in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``.  ``nn.Module.eval()`` returns the result of
            ``self.train(False)``, so an override that returns nothing makes
            both ``model.train()`` and ``model.eval()`` evaluate to ``None``.
        """
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Frozen statistics: never update running mean/var.
                m.eval()
        return self
class WideResNet50(nn.Module):
    """Wide ResNet-50-2 backbone that exposes the intermediate feature maps.

    Wraps ``tvm.wide_resnet50_2``.  With ``high_res`` enabled the input is
    upsampled by ``scale_factor`` before the network, and every returned
    feature map is scaled back by ``1 / scale_factor``.  Batch-norm layers are
    always kept in eval mode.
    """

    def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
        """Build the wrapped torchvision network.

        Args:
            pretrained: Forwarded to ``tvm.wide_resnet50_2`` when ``weights``
                is None.
            high_res: Enables the upsample-then-downsample path in
                ``forward``.
            weights: When not None, forwarded to ``tvm.wide_resnet50_2``
                instead of ``pretrained``.
        """
        super().__init__()
        if weights is not None:
            self.net = tvm.wide_resnet50_2(weights=weights)
        else:
            self.net = tvm.wide_resnet50_2(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5

    def _restore_scale(self, feat):
        """Scale ``feat`` by ``1 / scale_factor`` in high-res mode, else return it as is."""
        if not self.high_res:
            return feat
        return F.interpolate(feat, scale_factor=1/self.scale_factor, align_corners=False, mode="bilinear")

    def forward(self, x):
        """Run the backbone and collect the per-stage activations.

        Args:
            x: Input tensor; bicubic interpolation in high-res mode requires a
                4-D tensor.

        Returns:
            A dict with keys ``1, 2, 4, 8, 16, 32`` (in that insertion
            order).  Key 1 is the untouched input (before any upsampling),
            key 2 the activation after conv1/bn1/relu, and keys 4/8/16/32 the
            outputs of layer1..layer4, each passed through the high-res
            downscaling when enabled.
        """
        net = self.net
        feats = {1: x}
        if self.high_res:
            x = F.interpolate(x, scale_factor=self.scale_factor, align_corners=False, mode="bicubic")
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = self._restore_scale(x)
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = self._restore_scale(x)
        x = net.layer2(x)
        feats[8] = self._restore_scale(x)
        x = net.layer3(x)
        feats[16] = self._restore_scale(x)
        x = net.layer4(x)
        feats[32] = self._restore_scale(x)
        return feats

    def train(self, mode=True):
        """Set the training mode while keeping every BatchNorm2d in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``.  ``nn.Module.eval()`` returns the result of
            ``self.train(False)``, so an override that returns nothing makes
            both ``model.train()`` and ``model.eval()`` evaluate to ``None``.
        """
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                # Frozen statistics: never update running mean/var.
                m.eval()
        return self