Spaces:
Build error
Build error
import torch | |
from torch import nn | |
from torchvision.models import densenet169 | |
from config.finetune_config import set_args | |
# Options from config.finetune_config; `args.finetune_path` is read by
# M_DenseNet when it is built with pretrain != 'IN'.
# NOTE(review): this runs at import time — if set_args() parses sys.argv,
# importing this module from another script will parse that script's CLI
# arguments. Confirm against config/finetune_config.py.
args = set_args()
class Classifier(nn.Module):
    """Classification head over a concatenated pair of feature maps.

    Three parallel dilated convolutions ("GDConv" branches) each collapse the
    spatial grid to 1x1. Every branch output is layer-normalised and flattened.
    The three branch vectors are then concatenated, passed through GELU and
    dropout, and projected to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input feature map. The default
            ``1664 * 2`` matches the left/right feature maps that M_DenseNet
            concatenates along the channel dimension.
        hidden_channels: Output channels of each dilated-conv branch.
        dropout: Dropout probability applied before the final linear layer.

    Note:
        ``LayerNorm([hidden_channels, 1, 1])`` requires each branch to output a
        1x1 map. With the kernel/padding/dilation settings below, this holds
        for a 7x7 input: the effective kernels are 7, 9 (with padding 1) and 7.
    """

    def __init__(self, num_classes, in_channels=1664 * 2, hidden_channels=1024,
                 dropout=0.2):
        super(Classifier, self).__init__()
        # Attribute names and creation order are unchanged so that existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.GDConv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=4,
                                 padding=0, dilation=2)
        self.GDConv2 = nn.Conv2d(in_channels, hidden_channels, kernel_size=5,
                                 padding=1, dilation=2)
        self.GDConv3 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3,
                                 padding=0, dilation=3)
        self.LN1 = nn.LayerNorm([hidden_channels, 1, 1])
        self.LN2 = nn.LayerNorm([hidden_channels, 1, 1])
        self.LN3 = nn.LayerNorm([hidden_channels, 1, 1])
        self.gelu = nn.GELU()
        self.fc_dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_channels * 3, num_classes)
        # Kaiming-normal conv weights (conv biases keep PyTorch's default init)
        # and zero linear bias. This head has no BatchNorm layers to initialise.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a (batch, in_channels, H, W) feature map to (batch, num_classes) logits."""
        branches = (
            (self.GDConv1, self.LN1),
            (self.GDConv2, self.LN2),
            (self.GDConv3, self.LN3),
        )
        # Each branch yields (batch, hidden_channels, 1, 1) -> (batch, hidden_channels).
        flat = [torch.flatten(norm(conv(x)), 1) for conv, norm in branches]
        fused = self.gelu(torch.cat(flat, dim=1))
        return self.fc(self.fc_dropout(fused))
def _load_pickled_model(path):
    """Load a whole ``nn.Module`` that was saved with ``torch.save(model, path)``.

    The module is mapped to CPU for two reasons: loading then works on hosts
    without the GPU the checkpoint was saved from, and the module starts on the
    same device as a freshly constructed head. Recent PyTorch (2.6+) defaults
    ``torch.load`` to ``weights_only=True``, which cannot restore a full module.
    ``weights_only=False`` is therefore passed explicitly when this torch
    version accepts the argument.

    Security: this unpickles arbitrary Python objects, so only load trusted files.
    """
    import inspect

    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


class M_DenseNet(nn.Module):
    """Siamese DenseNet: one shared feature extractor applied to two images.

    Args:
        pretrain: ``'IN'`` builds the extractor from torchvision's
            ImageNet-pretrained DenseNet-169. Any other value loads a full
            pickled model from ``finetune_path``.
        num_classes: Number of logits produced by the classifier head.
        finetune_path: Checkpoint used when ``pretrain != 'IN'``. If None,
            ``args.finetune_path`` from the module-level config is used.
    """

    def __init__(self, pretrain='IN', num_classes=8, finetune_path=None):
        super(M_DenseNet, self).__init__()
        # Feature extractor, shared by the left and right inputs.
        if pretrain == 'IN':
            # The returned model already carries the pretrained weights.
            # NOTE: `pretrained=` is deprecated since torchvision 0.13 in favour
            # of `weights=`; kept here for compatibility with older installs.
            model = densenet169(pretrained=True)
            # Drop the last child (torchvision's DenseNet `classifier` head),
            # keeping the convolutional `features` stack.
            self.feature = nn.Sequential(*list(model.children())[:-1])
        else:
            if finetune_path is None:
                finetune_path = args.finetune_path
            model = _load_pickled_model(finetune_path)
            # NOTE(review): assumes the last two children of the fine-tuned
            # model are head layers to discard — confirm against the saved model.
            self.feature = nn.Sequential(*list(model.children())[:-2])
        self.classifier = Classifier(num_classes)

    def forward(self, left, right):
        """Return class logits for an image pair.

        Both images pass through the same ``self.feature`` weights. Their
        feature maps are concatenated along the channel dimension (dim 1) and
        then fed to the classifier head.
        """
        left = self.feature(left)
        right = self.feature(right)
        x = torch.cat((left, right), 1)
        return self.classifier(x)
if __name__ == '__main__':
    # Smoke test: one forward pass on a batch of random 224x224 image pairs.
    model = M_DenseNet()
    # Inference mode: disables dropout and stops BatchNorm from updating its
    # running statistics with random data.
    model.eval()
    input1 = torch.normal(0, 1, size=(4, 3, 224, 224))
    input2 = torch.normal(0, 1, size=(4, 3, 224, 224))
    # No gradients are needed for a shape/output check.
    with torch.no_grad():
        output = model(input1, input2)
    print(output.shape)
    print(output)