Spaces:
Running
Running
File size: 2,362 Bytes
3ad39d6 9b1a727 3ad39d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import sys
import argparse
import torch
import warnings
import numpy as np
from pathlib import Path
from torchvision.transforms import ToPILImage
from ..utils.base_model import BaseModel
sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
from COTR.utils import utils as utils_cotr
from COTR.models import build_model
from COTR.options.options import *
from COTR.options.options_utils import *
from COTR.inference.inference_helper import triangulate_corr
from COTR.inference.sparse_engine import SparseEngine
# Import-time side effects of this module.
# fix_randomness(0) presumably seeds COTR's random generators with 0 (confirm
# in COTR.utils). set_grad_enabled(False) turns autograd off for the importing
# thread (torch grad mode is thread-local), so it also affects the host program.
utils_cotr.fix_randomness(0)
torch.set_grad_enabled(False)

# Root of the vendored COTR checkout; checkpoint paths are resolved under it.
cotr_path = Path(__file__).parent.joinpath("../../third_party/COTR")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
class COTR(BaseModel):
    """Matcher that wraps the COTR model vendored under ``third_party/COTR``.

    The network is built from COTR's own argparse options, with weights read
    from ``cotr_path / conf["weights"] / "checkpoint.pth.tar"``. Matching is
    delegated to COTR's ``SparseEngine`` via
    ``cotr_corr_multiscale_with_cycle_consistency``.
    """

    default_conf = {
        # Checkpoint directory, relative to ``cotr_path``.
        "weights": "out/default",
        # Not read anywhere in this class.
        "match_threshold": 0.2,
        # Forwarded to SparseEngine as ``max_corrs``.
        # NOTE(review): how SparseEngine treats -1 is not visible here --
        # confirm it means "no limit".
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the COTR network and load its pretrained weights.

        Hook presumably invoked by ``BaseModel`` during construction.

        Args:
            conf: configuration dict; ``conf["weights"]`` selects the
                checkpoint directory under ``cotr_path``.
        """
        parser = argparse.ArgumentParser()
        set_COTR_arguments(parser)
        # Parse an explicit empty list so COTR always gets its declared
        # defaults. A bare ``parse_args()`` would read the *host* program's
        # sys.argv and exit with "unrecognized arguments" whenever the host
        # was launched with a flag COTR does not declare (and ``-h`` would
        # print COTR's help and exit).
        opt = parser.parse_args([])
        # Record the launching command line on the options object.
        opt.command = " ".join(sys.argv)
        opt.load_weights_path = str(
            cotr_path / conf["weights"] / "checkpoint.pth.tar"
        )
        # Feed-forward width used for the backbone layer named by opt.layer.
        layer_2_channels = {
            "layer1": 256,
            "layer2": 512,
            "layer3": 1024,
            "layer4": 2048,
        }
        opt.dim_feedforward = layer_2_channels[opt.layer]
        model = build_model(opt).to(device)
        # Deserialize on CPU, then copy into the model already on ``device``.
        # torch.load unpickles the file: only point "weights" at trusted
        # checkpoints.
        weights = torch.load(opt.load_weights_path, map_location="cpu")[
            "model_state_dict"
        ]
        utils_cotr.safe_load_weights(model, weights)
        self.net = model.eval()
        self.to_pil_func = ToPILImage(mode="RGB")

    def _to_rgb_array(self, image):
        """Convert one image tensor to an RGB numpy array via ToPILImage.

        Args:
            image: a single image tensor; assumed (C, H, W) with float values
                in [0, 1] -- TODO confirm against the caller's preprocessing.

        Returns:
            ``np.array`` of the RGB PIL image produced by ``self.to_pil_func``.
        """
        image = image.detach().cpu()
        # ToPILImage(mode="RGB") rejects one-channel input, so replicate a
        # grayscale channel into three identical RGB channels first.
        if image.dim() == 3 and image.shape[0] == 1:
            image = image.repeat(3, 1, 1)
        return np.array(self.to_pil_func(image))

    def _forward(self, data):
        """Compute COTR correspondences between ``image0`` and ``image1``.

        Only batch element 0 of each input is matched.

        Args:
            data: dict holding batched ``"image0"`` and ``"image1"`` tensors.

        Returns:
            dict with ``"keypoints0"`` (first two columns of the
            correspondence array) and ``"keypoints1"`` (remaining columns),
            both as tensors created with ``torch.from_numpy``.
        """
        img_a = self._to_rgb_array(data["image0"][0])
        img_b = self._to_rgb_array(data["image1"][0])
        # NOTE(review): 32 is presumably SparseEngine's inference batch size
        # -- confirm in COTR.inference.sparse_engine.
        engine = SparseEngine(self.net, 32, mode="tile")
        corrs = engine.cotr_corr_multiscale_with_cycle_consistency(
            img_a,
            img_b,
            # Four zoom levels from 0.5 down to 0.0625.
            np.linspace(0.5, 0.0625, 4),
            # NOTE(review): presumably the number of convergence iterations.
            1,
            max_corrs=self.conf["max_keypoints"],
            queries_a=None,
        )
        pred = {
            "keypoints0": torch.from_numpy(corrs[:, :2]),
            "keypoints1": torch.from_numpy(corrs[:, 2:]),
        }
        return pred
|