|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
|
|
from .nets.backbone import HourglassBackbone, SuperpointBackbone |
|
from .nets.junction_decoder import SuperpointDecoder |
|
from .nets.heatmap_decoder import PixelShuffleDecoder |
|
from .nets.descriptor_decoder import SuperpointDescriptor |
|
|
|
|
|
def get_model(model_cfg=None, loss_weights=None, mode="train"):
    """ Build the network described by the model configuration.

    Args:
        model_cfg: Mapping read for the "model_architecture", "backbone",
            "junction_decoder" and "heatmap_decoder" keys; it is also
            handed to the network constructor. Required.
        loss_weights: Mapping from a name to a loss weight. In "train" mode
            every value that is an nn.Parameter is registered on the model
            under its name; any other value is skipped.
        mode: Loss weights are only processed when this equals "train".

    Returns:
        The constructed model.

    Raises:
        ValueError: If model_cfg is None, if the architecture is not
            supported, or if loss_weights is None while mode is "train".
    """
    if model_cfg is None:
        raise ValueError("[Error] The model config is required!")

    print("\n\n\t--------Initializing model----------")
    arch_error = "[Error] The model architecture is not in supported arch!"
    known_archs = ["simple"]
    if model_cfg["model_architecture"] not in known_archs:
        raise ValueError(arch_error)

    # Second check kept so that an architecture listed as supported but not
    # actually wired to a constructor still fails loudly.
    if model_cfg["model_architecture"] != "simple":
        raise ValueError(arch_error)
    model = SOLD2Net(model_cfg)

    if mode == "train":
        if loss_weights is None:
            raise ValueError(
                "[Error] the loss weights can not be None in dynamic weighting mode during training.")
        # Lazily filter so that checking and registering stay interleaved
        # entry by entry.
        learnable = ((key, value) for key, value in loss_weights.items()
                     if isinstance(value, nn.Parameter))
        for weight_name, weight in learnable:
            print("\t [Debug] Adding %s with value %f to model"
                  % (weight_name, weight.item()))
            model.register_parameter(weight_name, weight)

    # Summary of the selected components.
    summary_lines = (
        ("\tModel architecture: %s", "model_architecture"),
        ("\tBackbone: %s", "backbone"),
        ("\tJunction decoder: %s", "junction_decoder"),
        ("\tHeatmap decoder: %s", "heatmap_decoder"),
    )
    for template, cfg_key in summary_lines:
        print(template % model_cfg[cfg_key])
    print("\t-------------------------------------")

    return model
|
|
|
|
|
class SOLD2Net(nn.Module):
    """ Full network for SOLD².

    A backbone encoder produces features that are fed to a junction decoder
    and a heatmap decoder. When the configuration contains a
    "descriptor_decoder" entry, a descriptor decoder is built and run on the
    same features as well.

    Args:
        model_cfg: Mapping with the keys "model_name", "backbone",
            "junction_decoder" and "heatmap_decoder". "backbone_cfg" is
            required for the "lcnn" backbone only. "descriptor_decoder" is
            optional.
    """

    def __init__(self, model_cfg):
        super().__init__()
        self.name = model_cfg["model_name"]
        self.cfg = model_cfg

        # Backbone encoder. It must be built first: the decoders below read
        # self.feat_channel.
        self.supported_backbone = ["lcnn", "superpoint"]
        self.backbone_net, self.feat_channel = self.get_backbone()

        # Junction decoder.
        self.supported_junction_decoder = ["superpoint_decoder"]
        self.junction_decoder = self.get_junction_decoder()

        # Heatmap decoder.
        self.supported_heatmap_decoder = ["pixel_shuffle",
                                          "pixel_shuffle_single"]
        self.heatmap_decoder = self.get_heatmap_decoder()

        # Descriptor decoder (only when requested by the configuration).
        if "descriptor_decoder" in self.cfg:
            self.supported_descriptor_decoder = ["superpoint_descriptor"]
            self.descriptor_decoder = self.get_descriptor_decoder()

        # Initialize the weights of every submodule.
        self.apply(weight_init)

    def forward(self, input_images):
        """ Run the backbone and every configured decoder.

        Returns:
            Dict with the keys "junctions" and "heatmap", plus "descriptors"
            when the configuration has a "descriptor_decoder" entry.
        """
        features = self.backbone_net(input_images)

        outputs = {
            "junctions": self.junction_decoder(features),
            "heatmap": self.heatmap_decoder(features),
        }

        if "descriptor_decoder" in self.cfg:
            outputs["descriptors"] = self.descriptor_decoder(features)

        return outputs

    def get_backbone(self):
        """ Retrieve the backbone encoder network.

        Returns:
            Tuple (backbone, feat_channel) where feat_channel is 256 for
            "lcnn" and 128 for "superpoint".

        Raises:
            ValueError: If the configured backbone is not supported.
        """
        backbone_name = self.cfg["backbone"]
        if backbone_name not in self.supported_backbone:
            raise ValueError(
                "[Error] The backbone selection is not supported.")

        if backbone_name == "lcnn":
            backbone = HourglassBackbone(**self.cfg["backbone_cfg"])
            feat_channel = 256
        elif backbone_name == "superpoint":
            # SuperpointBackbone is built without arguments, so
            # "backbone_cfg" is deliberately not read (and not required).
            backbone = SuperpointBackbone()
            feat_channel = 128
        else:
            # Reached only if supported_backbone lists a name that has no
            # constructor wired here.
            raise ValueError(
                "[Error] The backbone selection is not supported.")

        return backbone, feat_channel

    def get_junction_decoder(self):
        """ Get the junction decoder.

        Raises:
            ValueError: If the configured junction decoder is not supported.
        """
        decoder_name = self.cfg["junction_decoder"]
        if decoder_name not in self.supported_junction_decoder:
            raise ValueError(
                "[Error] The junction decoder selection is not supported.")

        if decoder_name == "superpoint_decoder":
            decoder = SuperpointDecoder(self.feat_channel,
                                        self.cfg["backbone"])
        else:
            raise ValueError(
                "[Error] The junction decoder selection is not supported.")

        return decoder

    def get_heatmap_decoder(self):
        """ Get the heatmap decoder.

        "pixel_shuffle" keeps the decoder's default output channels while
        "pixel_shuffle_single" requests a single output channel. The number
        of upsampling steps is 2 for the "lcnn" backbone and 3 for
        "superpoint".

        Raises:
            ValueError: If the configured heatmap decoder is not supported
                or the configured backbone is unknown.
        """
        decoder_name = self.cfg["heatmap_decoder"]
        if decoder_name not in self.supported_heatmap_decoder:
            raise ValueError(
                "[Error] The heatmap decoder selection is not supported.")

        if decoder_name == "pixel_shuffle":
            decoder_kwargs = {}
        elif decoder_name == "pixel_shuffle_single":
            decoder_kwargs = {"output_channel": 1}
        else:
            raise ValueError(
                "[Error] The heatmap decoder selection is not supported.")

        backbone_name = self.cfg["backbone"]
        if backbone_name == "lcnn":
            num_upsample = 2
        elif backbone_name == "superpoint":
            num_upsample = 3
        else:
            raise ValueError("[Error] Unknown backbone option.")

        return PixelShuffleDecoder(self.feat_channel,
                                   num_upsample=num_upsample,
                                   **decoder_kwargs)

    def get_descriptor_decoder(self):
        """ Get the descriptor decoder.

        Raises:
            ValueError: If the configured descriptor decoder is not
                supported.
        """
        decoder_name = self.cfg["descriptor_decoder"]
        if decoder_name not in self.supported_descriptor_decoder:
            raise ValueError(
                "[Error] The descriptor decoder selection is not supported.")

        if decoder_name == "superpoint_descriptor":
            decoder = SuperpointDescriptor(self.feat_channel)
        else:
            raise ValueError(
                "[Error] The descriptor decoder selection is not supported.")

        return decoder
|
|
|
|
|
def weight_init(m):
    """ Weight initialization function.

    Initializes a single module in place; modules of any other type are
    left untouched.

    - nn.Conv2d / nn.Linear: Xavier-normal weights, standard-normal bias.
    - nn.BatchNorm2d: weights drawn from N(1, 0.02), bias set to 0.

    Parameters that are None (a layer built with bias=False, or a batch
    norm built with affine=False) are skipped instead of raising.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        init.xavier_normal_(m.weight.data)
        # bias is None when the layer was created with bias=False.
        if m.bias is not None:
            init.normal_(m.bias.data)

    elif isinstance(m, nn.BatchNorm2d):
        # weight and bias are None when the layer was created with
        # affine=False.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
|
|