Spaces:
Running
Running
File size: 3,970 Bytes
4d4dd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
"""
A two-view sparse feature matching pipeline.
This model contains sub-models for each step:
feature extraction, feature matching, outlier filtering, pose estimation,
and ground-truth label computation.
Each step is optional, and the features or matches can be provided as input.
By default no sub-model is configured (every `name` in `default_conf` is None).
Convention for the matches: m0[i] is the index of the keypoint in image 1
that corresponds to the keypoint i in image 0. m0[i] = -1 if i is unmatched.
"""
from omegaconf import OmegaConf
from . import get_model
from .base_model import BaseModel
to_ctr = OmegaConf.to_container  # convert DictConfig to dict


class TwoViewPipeline(BaseModel):
    """Chain optional two-view sub-models into a single model.

    Stages run in the order extractor -> matcher -> filter -> solver, with an
    optional ground-truth component whose outputs are stored under
    ``gt_``-prefixed keys. A stage is enabled by giving its config section a
    ``name``; it is then built with ``get_model(name)(section_as_dict)``.
    """

    default_conf = {
        "extractor": {
            "name": None,
            "trainable": False,
        },
        "matcher": {"name": None},
        "filter": {"name": None},
        "solver": {"name": None},
        "ground_truth": {"name": None},
        # If truthy, a view whose "cache" dict is non-empty skips feature
        # extraction and its cached predictions are used as-is.
        "allow_no_extract": False,
        # If truthy, ground-truth labels are computed in `_forward`;
        # otherwise they are computed at the start of `loss`.
        "run_gt_in_forward": False,
    }
    required_data_keys = ["view0", "view1"]
    strict_conf = False  # need to pass new confs to children models
    # Component slots, in execution order. Each entry is both a key of the
    # config and the attribute name under which the built sub-model is stored.
    components = [
        "extractor",
        "matcher",
        "filter",
        "solver",
        "ground_truth",
    ]

    def _init(self, conf):
        """Build every component whose config section has a truthy name."""
        for component in self.components:
            component_conf = conf[component]
            if component_conf.name:
                model_cls = get_model(component_conf.name)
                setattr(self, component, model_cls(to_ctr(component_conf)))

    def extract_view(self, data, i):
        """Return the predictions for view ``i`` (``"0"`` or ``"1"``).

        Starts from ``data[f"view{i}"]["cache"]`` (an empty dict if absent).
        When an extractor is configured it is run on the raw view data and its
        outputs override cached entries with the same key. The exception is
        when the cache is non-empty and ``allow_no_extract`` is truthy: then
        the cache is returned unchanged.
        """
        data_i = data[f"view{i}"]
        pred_i = data_i.get("cache", {})
        skip_extract = len(pred_i) > 0 and self.conf.allow_no_extract
        if self.conf.extractor.name and not skip_extract:
            # NOTE(review): cached entries are not passed to the extractor.
            # If the extractor is meant to reuse them (e.g. precomputed
            # keypoints), it should be called on {**data_i, **pred_i}
            # instead -- confirm the intended behavior.
            pred_i = {**pred_i, **self.extractor(data_i)}
        return pred_i

    def _forward(self, data):
        """Run the enabled stages in order and merge their outputs.

        Per-view predictions get the view index appended to each key
        (e.g. ``keypoints`` -> ``keypoints0``). Each later stage receives
        ``{**data, **pred}`` and its outputs are merged into ``pred``,
        overriding earlier keys with the same name.
        """
        pred0 = self.extract_view(data, "0")
        pred1 = self.extract_view(data, "1")
        pred = {
            **{k + "0": v for k, v in pred0.items()},
            **{k + "1": v for k, v in pred1.items()},
        }
        if self.conf.matcher.name:
            pred = {**pred, **self.matcher({**data, **pred})}
        if self.conf.filter.name:
            pred = {**pred, **self.filter({**data, **pred})}
        if self.conf.solver.name:
            pred = {**pred, **self.solver({**data, **pred})}
        if self.conf.ground_truth.name and self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
        return pred

    def loss(self, pred, data):
        """Aggregate the losses and metrics of all enabled components.

        Unless already done in `_forward`, ground-truth labels are computed
        first and written into ``pred`` in place (``gt_``-prefixed keys), so
        the caller's dict is mutated. A component is skipped when its config
        has a falsy ``apply_loss`` or when its ``loss`` raises
        NotImplementedError (presumably the base-class default -- confirm).

        Returns:
            A ``(losses, metrics)`` tuple of merged per-component dicts.
            ``losses["total"]`` is the sum of every contributing component's
            ``"total"`` entry, or 0 if none contributed.
        """
        losses = {}
        metrics = {}
        total = 0
        # get labels
        if self.conf.ground_truth.name and not self.conf.run_gt_in_forward:
            gt_pred = self.ground_truth({**data, **pred})
            pred.update({f"gt_{k}": v for k, v in gt_pred.items()})
        for k in self.components:
            component_conf = self.conf[k]
            apply = True
            if "apply_loss" in component_conf.keys():
                apply = component_conf.apply_loss
            if not (component_conf.name and apply):
                continue
            try:
                losses_, metrics_ = getattr(self, k).loss(
                    pred, {**pred, **data}
                )
            except NotImplementedError:
                continue
            losses.update(losses_)
            metrics.update(metrics_)
            total = losses_["total"] + total
        return {**losses, "total": total}, metrics
|