from models.networks.modules import *


class BaseNet(nn.Module):
    """Encoder/decoder backbone shared by ``IAN`` and ``ANSN``.

    The network is made of one input block, three ``Down`` stages, three
    ``Up`` stages that each also receive the output of the matching
    encoder stage, and an ``OutConv`` head.

    NOTE(review): the ``Up`` input widths (256, 128, 64) equal the sum of
    the two tensors handed to each stage, which suggests channel-wise
    concatenation of the skip input -- confirm in ``modules.Up``.

    Args:
        in_channels: channel count of the input; also stored as
            ``n_channels``.
        out_channels: channel count produced by the head; also stored as
            ``n_classes``.
        norm: forwarded as ``norm=`` to every ``DoubleConv``/``Down``/``Up``
            block (not to the ``OutConv`` head).
    """

    def __init__(self, in_channels=1, out_channels=1, norm=True):
        super(BaseNet, self).__init__()
        self.n_channels = in_channels
        self.n_classes = out_channels

        # Encoder path: 32 -> 64 -> 128 -> 128 feature channels.
        self.inc = DoubleConv(in_channels, 32, norm=norm)
        self.down1 = Down(32, 64, norm=norm)
        self.down2 = Down(64, 128, norm=norm)
        self.down3 = Down(128, 128, norm=norm)

        # Decoder path: each stage is paired with one encoder output.
        self.up1 = Up(256, 64, bilinear=True, norm=norm)
        self.up2 = Up(128, 32, bilinear=True, norm=norm)
        self.up3 = Up(64, 32, bilinear=True, norm=norm)

        # Output head mapping the 32 decoder channels to ``out_channels``.
        self.outc = OutConv(32, out_channels)

    def forward(self, x):
        """Run ``x`` through the encoder, the decoder and the output head."""
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        bottom = self.down3(enc3)

        # Decoder stages consume the encoder outputs deepest-first.
        dec = self.up1(bottom, enc3)
        dec = self.up2(dec, enc2)
        dec = self.up3(dec, enc1)

        return self.outc(dec)


class IAN(BaseNet):
    """``BaseNet`` under a distinct class name; no layers are changed."""

    def __init__(self, in_channels=1, out_channels=1, norm=True):
        super(IAN, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            norm=norm,
        )


class ANSN(BaseNet):
    """``BaseNet`` whose output head is rebuilt with ``act=False``.

    NOTE(review): ``act=False`` presumably disables the head's final
    activation -- confirm against ``modules.OutConv``.
    """

    def __init__(self, in_channels=1, out_channels=1, norm=True):
        super(ANSN, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            norm=norm,
        )
        # Replaces the head that BaseNet.__init__ just registered.
        self.outc = OutConv(32, out_channels, act=False)


class FuseNet(nn.Module):
    """Two-level encoder/decoder built from the ``Attentive*`` blocks.

    Layout: one input block, two ``AttentiveDown`` stages, two
    ``AttentiveUp`` stages that each also receive the output of the
    matching encoder stage, and an ``OutConv`` head. Every attentive block
    is constructed with ``leaky=False``.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count produced by the head.
        norm: forwarded as ``norm=`` to every attentive block; defaults to
            ``False`` here, unlike ``BaseNet``.
    """

    def __init__(self, in_channels=1, out_channels=1, norm=False):
        super(FuseNet, self).__init__()

        # Encoder path: 32 -> 64 -> 64 feature channels.
        self.inc = AttentiveDoubleConv(in_channels, 32, norm=norm, leaky=False)
        self.down1 = AttentiveDown(32, 64, norm=norm, leaky=False)
        self.down2 = AttentiveDown(64, 64, norm=norm, leaky=False)

        # Decoder path.
        self.up1 = AttentiveUp(128, 32, bilinear=True, norm=norm, leaky=False)
        self.up2 = AttentiveUp(64, 32, bilinear=True, norm=norm, leaky=False)

        # Output head mapping the 32 decoder channels to ``out_channels``.
        self.outc = OutConv(32, out_channels)

    def forward(self, x):
        """Run ``x`` through the encoder, the decoder and the output head."""
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        bottom = self.down2(enc2)

        dec = self.up1(bottom, enc2)
        dec = self.up2(dec, enc1)

        return self.outc(dec)


if __name__ == '__main__':
    # Print the parameter/buffer names of a 4-in / 2-out FuseNet.
    fuse_net = FuseNet(4, 2)
    for param_name in fuse_net.state_dict():
        print(param_name)