Realcat
update:sift and update lightglue
d46c0a9
raw
history blame
1.5 kB
import torch
from kornia.color import rgb_to_grayscale
from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
from .sift import SIFT
class DoGHardNet(SIFT):
    """SIFT-based extractor whose descriptors are replaced by HardNet ones.

    Keypoints, scales and orientations are produced by the parent ``SIFT``
    extractor (via ``extract_single_image``). Each keypoint is then
    re-described with a 128-dimensional descriptor computed by kornia's
    ``LAFDescriptor(HardNet)`` on a local affine frame (LAF) built from the
    keypoint's centre, scale and orientation.
    """

    required_data_keys = ["image"]

    def __init__(self, **conf):
        """Build the parent extractor and the HardNet LAF descriptor.

        Args:
            **conf: Configuration forwarded unchanged to ``SIFT.__init__``.
        """
        super().__init__(**conf)
        # HardNet(True): the positional flag is kornia's ``pretrained``
        # switch; .eval() puts the descriptor in inference mode.
        self.laf_desc = LAFDescriptor(HardNet(True)).eval()

    def forward(self, data: dict) -> dict:
        """Extract keypoints per image and attach HardNet descriptors.

        Args:
            data: Input dictionary. ``data["image"]`` is a batch of images
                indexed as (batch, channels, height, width); a 3-channel
                batch is converted to grayscale first. If
                ``data["image_size"]`` is present (and not ``None``) it must
                hold one ``(w, h)`` pair per image, and each image is
                cropped to its top-left ``h`` x ``w`` region before
                extraction.

        Returns:
            A dictionary with every key returned by
            ``extract_single_image`` plus ``"descriptors"``. Each value is
            the per-image results stacked along a new leading batch
            dimension and moved to the device of the input image.

        Raises:
            IndexError: If the batch contains no images.
        """
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device

        self.laf_desc = self.laf_desc.to(device)
        # Eval mode is re-asserted on every call; presumably this guards
        # against an enclosing module's .train() flipping the descriptor
        # back to training mode -- confirm before removing.
        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()

        # Normalise the optional sizes to integers once, up front.
        sizes = data.get("image_size")
        if sizes is not None:
            sizes = sizes.long()

        per_image = []
        for idx, img in enumerate(image):
            if sizes is not None:
                # Sizes are stored as (width, height); the image is indexed
                # as (channels, height, width).
                w, h = sizes[idx].tolist()
                img = img[:, :h, :w]

            feats = self.extract_single_image(img)

            # kornia takes LAF orientations in degrees, hence rad2deg.
            # NOTE(review): the 6.0 factor looks like the measurement-region
            # multiplier mapping keypoint scale to patch extent -- confirm.
            lafs = laf_from_center_scale_ori(
                feats["keypoints"].reshape(1, -1, 2),
                6.0 * feats["scales"].reshape(1, -1, 1, 1),
                torch.rad2deg(feats["oris"]).reshape(1, -1, 1),
            ).to(device)
            feats["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
            per_image.append(feats)

        # torch.stack needs identically shaped tensors per key, so every
        # image must yield the same number of features -- presumably
        # ensured by extract_single_image; verify against SIFT.
        return {
            key: torch.stack([f[key] for f in per_image], 0).to(device)
            for key in per_image[0]
        }