Realcat
init: files
4bde5d3
raw
history blame
2.33 kB
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
from torchvision.transforms import ToPILImage
from ..utils.base_model import BaseModel
sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
from COTR.inference.sparse_engine import SparseEngine
from COTR.models import build_model
from COTR.options.options import * # noqa: F403
from COTR.options.options_utils import * # noqa: F403
from COTR.utils import utils as utils_cotr
# Module-level setup; runs once when this module is imported.

# Presumably seeds the random number generators COTR relies on (fixed
# value 0) -- confirm against COTR.utils.utils.fix_randomness.
utils_cotr.fix_randomness(0)

# This wrapper is inference-only, so autograd is switched off globally.
# NOTE: this affects the whole process, not just this module.
torch.set_grad_enabled(False)

# Location of the vendored COTR checkout, expressed relative to this file.
cotr_path = Path(__file__).parent.joinpath("../../third_party/COTR")

# Use CUDA when it is available, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
class COTR(BaseModel):
    """Adapter that runs the third-party COTR matcher behind ``BaseModel``.

    ``_init`` builds the COTR network from COTR's own argparse options and
    loads a checkpoint from ``cotr_path / conf["weights"]``; ``_forward``
    matches one image pair and returns the corresponding keypoints.
    """

    # "weights": checkpoint directory, relative to the COTR checkout.
    # "match_threshold": not referenced anywhere inside this class.
    # "max_keypoints": forwarded to COTR as ``max_corrs``.
    default_conf = {
        "weights": "out/default",
        "match_threshold": 0.2,
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the COTR model and load its checkpoint.

        Args:
            conf: Configuration mapping; only ``conf["weights"]`` is read
                here.

        Raises:
            KeyError: If ``opt.layer`` is not one of ``layer1``..``layer4``.
        """
        parser = argparse.ArgumentParser()
        set_COTR_arguments(parser)  # noqa: F405
        # COTR exposes its hyper-parameters only through argparse. Use
        # parse_known_args() rather than parse_args(): this class is usually
        # instantiated inside a host program that has its own command-line
        # flags, and parse_args() would terminate the process on any flag
        # COTR does not recognise. COTR options given on the command line
        # are still honoured.
        opt, _unknown_args = parser.parse_known_args()
        opt.command = " ".join(sys.argv)
        opt.load_weights_path = str(
            cotr_path / conf["weights"] / "checkpoint.pth.tar"
        )

        # Feed-forward width is chosen from the selected backbone layer.
        layer_2_channels = {
            "layer1": 256,
            "layer2": 512,
            "layer3": 1024,
            "layer4": 2048,
        }
        opt.dim_feedforward = layer_2_channels[opt.layer]

        model = build_model(opt)
        model = model.to(device)
        # Deserialize onto the CPU first so loading does not depend on the
        # device the checkpoint was saved from.
        weights = torch.load(opt.load_weights_path, map_location="cpu")[
            "model_state_dict"
        ]
        utils_cotr.safe_load_weights(model, weights)
        self.net = model.eval()
        self.to_pil_func = ToPILImage(mode="RGB")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]`` with COTR.

        Only element ``[0]`` of each input is used (presumably the single
        item of a batch -- confirm against the caller).

        Args:
            data: Mapping holding the ``"image0"`` and ``"image1"`` tensors.

        Returns:
            Dict with ``"keypoints0"`` (first two columns of COTR's
            correspondence array) and ``"keypoints1"`` (remaining columns),
            both as CPU tensors created with ``torch.from_numpy``.
        """
        # COTR's engine is fed NumPy arrays, so go tensor -> PIL -> ndarray.
        img_a = np.array(self.to_pil_func(data["image0"][0].cpu()))
        img_b = np.array(self.to_pil_func(data["image1"][0].cpu()))

        # Positional arguments below are passed exactly as before; their
        # meaning (presumably batch size 32, zoom levels, and one
        # convergence iteration) should be confirmed against SparseEngine.
        corrs = SparseEngine(
            self.net, 32, mode="tile"
        ).cotr_corr_multiscale_with_cycle_consistency(
            img_a,
            img_b,
            np.linspace(0.5, 0.0625, 4),
            1,
            max_corrs=self.conf["max_keypoints"],
            queries_a=None,
        )

        pred = {
            "keypoints0": torch.from_numpy(corrs[:, :2]),
            "keypoints1": torch.from_numpy(corrs[:, 2:]),
        }
        return pred