diff --git a/hloc/extractors/alike.py b/hloc/extractors/alike.py index 1f6bf1fa1f5fbaffca7c0101fdca6d2756d5aa3b..4da3ca3acd734e4add9ae2883b6e06515ad57ad2 100644 --- a/hloc/extractors/alike.py +++ b/hloc/extractors/alike.py @@ -36,13 +36,13 @@ class Alike(BaseModel): ), ) logger.info("Loaded Alike model from {}".format(model_path)) + configs[conf["model_name"]]["model_path"] = model_path self.net = Alike_( **configs[conf["model_name"]], device=device, top_k=conf["top_k"], scores_th=conf["detection_threshold"], n_limit=conf["max_keypoints"], - model_path=model_path, ) logger.info("Load Alike model done.") diff --git a/hloc/extractors/darkfeat.py b/hloc/extractors/darkfeat.py index 643802c1ea4b997b6ae27d35e9b17350d19e4631..5c5d34d04bc59ece13c35a8363e7cf563b7709c7 100644 --- a/hloc/extractors/darkfeat.py +++ b/hloc/extractors/darkfeat.py @@ -23,7 +23,7 @@ class DarkFeat(BaseModel): def _init(self, conf): model_path = self._download_model( repo_id=MODEL_REPO_ID, - filename="{}/{}.pth".format( + filename="{}/{}".format( Path(__file__).stem, self.conf["model_name"] ), ) diff --git a/hloc/extractors/lanet.py b/hloc/extractors/lanet.py index 6d96af78bd7e3e5fae7b73808fd3e6df55a9dc87..7869c40ad70f82f5fe1e3c506c20e58c1c4780e2 100644 --- a/hloc/extractors/lanet.py +++ b/hloc/extractors/lanet.py @@ -33,10 +33,6 @@ class LANet(BaseModel): Path(__file__).stem, self.conf["model_name"] ), ) - if not model_path.exists(): - logger.warning( - f"No model found at {model_path}, please download it first." - ) self.net = PointModel(is_test=True) state_dict = torch.load(model_path, map_location="cpu") self.net.load_state_dict(state_dict["model_state"]) diff --git a/hloc/extractors/sfd2.py b/hloc/extractors/sfd2.py index 9a1979bb9e9ee35130bfe157d06b8a160d9e5e68..4a4ea50f52793d07c05ef79c71ded477e85ff14f 100644 --- a/hloc/extractors/sfd2.py +++ b/hloc/extractors/sfd2.py @@ -27,7 +27,7 @@ class SFD2(BaseModel): model_path = self._download_model( repo_id=MODEL_REPO_ID, filename="{}/{}".format( - Path(__file__).stem, self.conf["model_name"] + "pram", self.conf["model_name"] ), ) self.net = load_sfd2(weight_path=model_path).eval() diff --git a/hloc/match_dense.py b/hloc/match_dense.py index 30d35422cf085744f3436da3c321edbb1ae69b4f..0335486babd4206fb963db4d43005bb26b25ad8e 100644 --- a/hloc/match_dense.py +++ b/hloc/match_dense.py @@ -257,6 +257,7 @@ confs = { "model": { "name": "roma", "weights": "outdoor", + "model_name": "roma_outdoor.pth", "max_keypoints": 2000, "match_threshold": 0.2, }, @@ -273,7 +274,7 @@ confs = { "output": "matches-gim", "model": { "name": "gim", - "weights": "gim_dkm_100h.ckpt", + "model_name": "gim_dkm_100h.ckpt", "max_keypoints": 2000, "match_threshold": 0.2, }, diff --git a/hloc/matchers/gim.py b/hloc/matchers/gim.py index 158cf99195f18782c97c08842ed5438f9d9122c2..9d397003f6ae1d76a13738868cca68e6ff05f52a 100644 --- a/hloc/matchers/gim.py +++ b/hloc/matchers/gim.py @@ -3,46 +3,116 @@ from pathlib import Path import torch -from .. import MODEL_REPO_ID, logger +from .. 
import MODEL_REPO_ID, logger, DEVICE from ..utils.base_model import BaseModel gim_path = Path(__file__).parent / "../../third_party/gim" sys.path.append(str(gim_path)) -from dkm.models.model_zoo.DKMv3 import DKMv3 +def load_model(weight_name, checkpoints_path): + # load model + model = None + detector = None + if weight_name == "gim_dkm": + from gim.dkm.models.model_zoo.DKMv3 import DKMv3 + model = DKMv3(weights=None, h=672, w=896) + elif weight_name == "gim_loftr": + from gim.loftr.loftr import LoFTR + from gim.loftr.misc import lower_config + from gim.loftr.config import get_cfg_defaults + + model = LoFTR(lower_config(get_cfg_defaults())["loftr"]) + elif weight_name == "gim_lightglue": + from gim.lightglue.superpoint import SuperPoint + from gim.lightglue.models.matchers.lightglue import LightGlue + + detector = SuperPoint( + { + "max_num_keypoints": 2048, + "force_num_keypoints": True, + "detection_threshold": 0.0, + "nms_radius": 3, + "trainable": False, + } + ) + model = LightGlue( + { + "filter_threshold": 0.1, + "flash": False, + "checkpointed": True, + } + ) + + # load state dict + if weight_name == "gim_dkm": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("model."): + state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) + if "encoder.net.fc" in k: + state_dict.pop(k) + model.load_state_dict(state_dict) + + elif weight_name == "gim_loftr": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + model.load_state_dict(state_dict) + + elif weight_name == "gim_lightglue": + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("model."): + state_dict.pop(k) + if k.startswith("superpoint."): + state_dict[k.replace("superpoint.", "", 1)] = state_dict.pop(k) + detector.load_state_dict(state_dict) + + state_dict = torch.load(checkpoints_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + for k in list(state_dict.keys()): + if k.startswith("superpoint."): + state_dict.pop(k) + if k.startswith("model."): + state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) + model.load_state_dict(state_dict) + + # eval mode + if detector is not None: + detector = detector.eval().to(DEVICE) + model = model.eval().to(DEVICE) + return model class GIM(BaseModel): default_conf = { - "model_name": "gim_lightglue_100h.ckpt", "match_threshold": 0.2, "checkpoint_dir": gim_path / "weights", + "weights": "gim_dkm", } required_inputs = [ "image0", "image1", ] + ckpt_name_dict = { + "gim_dkm": "gim_dkm_100h.ckpt", + "gim_loftr": "gim_loftr_50h.ckpt", + "gim_lightglue": "gim_lightglue_100h.ckpt", + } def _init(self, conf): + ckpt_name = self.ckpt_name_dict[conf["weights"]] model_path = self._download_model( repo_id=MODEL_REPO_ID, - filename="{}/{}".format( - Path(__file__).stem, self.conf["model_name"] - ), + filename="{}/{}".format(Path(__file__).stem, ckpt_name), ) - self.aspect_ratio = 896 / 672 - model = DKMv3(None, 672, 896, upsample_preds=True) - state_dict = torch.load(str(model_path), map_location="cpu") - if "state_dict" in state_dict.keys(): - state_dict = state_dict["state_dict"] - for k in list(state_dict.keys()): - if k.startswith("model."): - 
state_dict[k.replace("model.", "", 1)] = state_dict.pop(k) - if "encoder.net.fc" in k: - state_dict.pop(k) - model.load_state_dict(state_dict) - + self.aspect_ratio = 896 / 672 # _forward still pads inputs to this ratio + model = load_model(conf["weights"], model_path) self.net = model logger.info("Loaded GIM model") @@ -94,6 +164,7 @@ class GIM(BaseModel): return mask def _forward(self, data): + # TODO: only the gim_dkm variant is supported here for now image0, image1 = self.pad_image( data["image0"], self.aspect_ratio ), self.pad_image(data["image1"], self.aspect_ratio) diff --git a/hloc/matchers/imp.py b/hloc/matchers/imp.py index e7197a58c48bf467ec5fda777637d1dd40c6b9e5..1867870c24a382fe1e3bbd4251969797dd5cb2f9 100644 --- a/hloc/matchers/imp.py +++ b/hloc/matchers/imp.py @@ -34,7 +34,7 @@ class IMP(BaseModel): model_path = self._download_model( repo_id=MODEL_REPO_ID, filename="{}/{}".format( - Path(__file__).stem, self.conf["model_name"] + "pram", self.conf["model_name"] ), ) diff --git a/hloc/matchers/lightglue.py b/hloc/matchers/lightglue.py index 95947e6c3cc3d87e51a437a0dfe16f6f8aa52cbf..975b55485276975f12f18aefb9f71727c9b5aa22 100644 --- a/hloc/matchers/lightglue.py +++ b/hloc/matchers/lightglue.py @@ -33,6 +33,7 @@ class LightGlue(BaseModel): ] def _init(self, conf): + logger.info("Loading LightGlue model {}".format(conf["model_name"])) model_path = self._download_model( repo_id=MODEL_REPO_ID, filename="{}/{}".format(
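Note on the GIM matcher above: the checkpoint is now selected through the "weights" key and mapped to a filename via GIM.ckpt_name_dict, rather than passed in as a raw checkpoint name. A minimal usage sketch under that reading (the conf values below are illustrative, not part of this patch):

    from hloc.matchers.gim import GIM

    conf = {
        "weights": "gim_dkm",  # any key of GIM.ckpt_name_dict
        "match_threshold": 0.2,
    }
    matcher = GIM(conf)  # fetches gim_dkm_100h.ckpt and builds DKMv3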
diff --git a/third_party/gim/gim/__init__.py b/third_party/gim/gim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/third_party/gim/gim/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/third_party/gim/dkm/__init__.py b/third_party/gim/gim/dkm/__init__.py similarity index 100% rename from third_party/gim/dkm/__init__.py rename to third_party/gim/gim/dkm/__init__.py diff --git a/third_party/gim/dkm/benchmarks/__init__.py b/third_party/gim/gim/dkm/benchmarks/__init__.py similarity index 100% rename from third_party/gim/dkm/benchmarks/__init__.py rename to third_party/gim/gim/dkm/benchmarks/__init__.py diff --git a/third_party/gim/dkm/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/gim/gim/dkm/benchmarks/hpatches_sequences_homog_benchmark.py similarity index 100% rename from third_party/gim/dkm/benchmarks/hpatches_sequences_homog_benchmark.py rename to third_party/gim/gim/dkm/benchmarks/hpatches_sequences_homog_benchmark.py diff --git a/third_party/gim/dkm/benchmarks/megadepth1500_benchmark.py b/third_party/gim/gim/dkm/benchmarks/megadepth1500_benchmark.py similarity index 100% rename from third_party/gim/dkm/benchmarks/megadepth1500_benchmark.py rename to third_party/gim/gim/dkm/benchmarks/megadepth1500_benchmark.py diff --git a/third_party/gim/dkm/benchmarks/megadepth_dense_benchmark.py b/third_party/gim/gim/dkm/benchmarks/megadepth_dense_benchmark.py similarity index 100% rename from third_party/gim/dkm/benchmarks/megadepth_dense_benchmark.py rename to third_party/gim/gim/dkm/benchmarks/megadepth_dense_benchmark.py diff --git a/third_party/gim/dkm/benchmarks/scannet_benchmark.py b/third_party/gim/gim/dkm/benchmarks/scannet_benchmark.py similarity index 100% rename from third_party/gim/dkm/benchmarks/scannet_benchmark.py rename to third_party/gim/gim/dkm/benchmarks/scannet_benchmark.py diff --git a/third_party/gim/dkm/checkpointing/__init__.py b/third_party/gim/gim/dkm/checkpointing/__init__.py similarity index 100% rename from third_party/gim/dkm/checkpointing/__init__.py rename to third_party/gim/gim/dkm/checkpointing/__init__.py diff --git a/third_party/gim/dkm/checkpointing/checkpoint.py b/third_party/gim/gim/dkm/checkpointing/checkpoint.py similarity index 100% rename from third_party/gim/dkm/checkpointing/checkpoint.py rename to third_party/gim/gim/dkm/checkpointing/checkpoint.py diff --git a/third_party/gim/dkm/datasets/__init__.py b/third_party/gim/gim/dkm/datasets/__init__.py similarity index 100% rename from third_party/gim/dkm/datasets/__init__.py rename to third_party/gim/gim/dkm/datasets/__init__.py diff --git a/third_party/gim/dkm/datasets/megadepth.py b/third_party/gim/gim/dkm/datasets/megadepth.py similarity index 100% rename from third_party/gim/dkm/datasets/megadepth.py rename to third_party/gim/gim/dkm/datasets/megadepth.py diff --git a/third_party/gim/dkm/datasets/scannet.py b/third_party/gim/gim/dkm/datasets/scannet.py similarity index 100% rename from third_party/gim/dkm/datasets/scannet.py rename to third_party/gim/gim/dkm/datasets/scannet.py diff --git a/third_party/gim/dkm/losses/__init__.py b/third_party/gim/gim/dkm/losses/__init__.py similarity index 100% rename from third_party/gim/dkm/losses/__init__.py rename to third_party/gim/gim/dkm/losses/__init__.py diff --git a/third_party/gim/dkm/losses/depth_match_regression_loss.py b/third_party/gim/gim/dkm/losses/depth_match_regression_loss.py similarity index 100% rename from third_party/gim/dkm/losses/depth_match_regression_loss.py rename to third_party/gim/gim/dkm/losses/depth_match_regression_loss.py diff --git a/third_party/gim/dkm/models/__init__.py b/third_party/gim/gim/dkm/models/__init__.py similarity index 100% rename from third_party/gim/dkm/models/__init__.py rename to third_party/gim/gim/dkm/models/__init__.py diff --git a/third_party/gim/dkm/models/dkm.py b/third_party/gim/gim/dkm/models/dkm.py similarity index 99% rename from third_party/gim/dkm/models/dkm.py rename to third_party/gim/gim/dkm/models/dkm.py index edf5641029e53866be80e679a4d71ae781348344..0cc6d35d5165c797ef7fdd4bc16ac405efa9b02f 100644 --- a/third_party/gim/dkm/models/dkm.py +++ b/third_party/gim/gim/dkm/models/dkm.py @@ -5,9 +5,9 @@ from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F -from dkm.utils import get_tuple_transform_ops +from gim.dkm.utils import get_tuple_transform_ops from einops import rearrange -from dkm.utils.local_correlation import local_correlation +from gim.dkm.utils.local_correlation import local_correlation class ConvRefiner(nn.Module): @@ -609,7 +609,7 @@ class RegressionMatcher(nn.Module): if "balanced" not in self.sample_mode: return good_matches, good_certainty - from dkm.utils.kde import kde + from gim.dkm.utils.kde import kde density = kde(good_matches, std=0.1) p = 1 / (density+1) p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones diff --git a/third_party/gim/dkm/models/encoders.py b/third_party/gim/gim/dkm/models/encoders.py similarity index 100% rename from third_party/gim/dkm/models/encoders.py rename to third_party/gim/gim/dkm/models/encoders.py diff --git a/third_party/gim/dkm/models/model_zoo/DKMv3.py b/third_party/gim/gim/dkm/models/model_zoo/DKMv3.py similarity index 98% rename from third_party/gim/dkm/models/model_zoo/DKMv3.py rename to third_party/gim/gim/dkm/models/model_zoo/DKMv3.py index 05285764d6f208cbd9a55c721caae91d04c25ecd..57f8a8bce35a8b499ece5c11ca42659f4197a95b 100644 --- a/third_party/gim/dkm/models/model_zoo/DKMv3.py +++ b/third_party/gim/gim/dkm/models/model_zoo/DKMv3.py @@ -1,8
+1,8 @@ import torch from torch import nn -from dkm.models.dkm import * -from dkm.models.encoders import * +from gim.dkm.models.dkm import * +from gim.dkm.models.encoders import * def DKMv3(weights, h, w, symmetric = True, sample_mode= "threshold_balanced", **kwargs): diff --git a/third_party/gim/dkm/models/model_zoo/__init__.py b/third_party/gim/gim/dkm/models/model_zoo/__init__.py similarity index 100% rename from third_party/gim/dkm/models/model_zoo/__init__.py rename to third_party/gim/gim/dkm/models/model_zoo/__init__.py diff --git a/third_party/gim/dkm/train/__init__.py b/third_party/gim/gim/dkm/train/__init__.py similarity index 100% rename from third_party/gim/dkm/train/__init__.py rename to third_party/gim/gim/dkm/train/__init__.py diff --git a/third_party/gim/dkm/train/train.py b/third_party/gim/gim/dkm/train/train.py similarity index 100% rename from third_party/gim/dkm/train/train.py rename to third_party/gim/gim/dkm/train/train.py diff --git a/third_party/gim/dkm/utils/__init__.py b/third_party/gim/gim/dkm/utils/__init__.py similarity index 100% rename from third_party/gim/dkm/utils/__init__.py rename to third_party/gim/gim/dkm/utils/__init__.py diff --git a/third_party/gim/dkm/utils/kde.py b/third_party/gim/gim/dkm/utils/kde.py similarity index 100% rename from third_party/gim/dkm/utils/kde.py rename to third_party/gim/gim/dkm/utils/kde.py diff --git a/third_party/gim/dkm/utils/local_correlation.py b/third_party/gim/gim/dkm/utils/local_correlation.py similarity index 100% rename from third_party/gim/dkm/utils/local_correlation.py rename to third_party/gim/gim/dkm/utils/local_correlation.py diff --git a/third_party/gim/dkm/utils/transforms.py b/third_party/gim/gim/dkm/utils/transforms.py similarity index 100% rename from third_party/gim/dkm/utils/transforms.py rename to third_party/gim/gim/dkm/utils/transforms.py diff --git a/third_party/gim/dkm/utils/utils.py b/third_party/gim/gim/dkm/utils/utils.py similarity index 100% rename from third_party/gim/dkm/utils/utils.py rename to third_party/gim/gim/dkm/utils/utils.py diff --git a/third_party/gim/gluefactory/__init__.py b/third_party/gim/gim/gluefactory/__init__.py similarity index 100% rename from third_party/gim/gluefactory/__init__.py rename to third_party/gim/gim/gluefactory/__init__.py diff --git a/third_party/gim/gluefactory/configs/aliked+NN.yaml b/third_party/gim/gim/gluefactory/configs/aliked+NN.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/aliked+NN.yaml rename to third_party/gim/gim/gluefactory/configs/aliked+NN.yaml diff --git a/third_party/gim/gluefactory/configs/aliked+lightglue-official.yaml b/third_party/gim/gim/gluefactory/configs/aliked+lightglue-official.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/aliked+lightglue-official.yaml rename to third_party/gim/gim/gluefactory/configs/aliked+lightglue-official.yaml diff --git a/third_party/gim/gluefactory/configs/aliked+lightglue_homography.yaml b/third_party/gim/gim/gluefactory/configs/aliked+lightglue_homography.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/aliked+lightglue_homography.yaml rename to third_party/gim/gim/gluefactory/configs/aliked+lightglue_homography.yaml diff --git a/third_party/gim/gluefactory/configs/aliked+lightglue_megadepth.yaml b/third_party/gim/gim/gluefactory/configs/aliked+lightglue_megadepth.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/aliked+lightglue_megadepth.yaml rename to 
third_party/gim/gim/gluefactory/configs/aliked+lightglue_megadepth.yaml diff --git a/third_party/gim/gluefactory/configs/disk+NN.yaml b/third_party/gim/gim/gluefactory/configs/disk+NN.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/disk+NN.yaml rename to third_party/gim/gim/gluefactory/configs/disk+NN.yaml diff --git a/third_party/gim/gluefactory/configs/disk+lightglue-official.yaml b/third_party/gim/gim/gluefactory/configs/disk+lightglue-official.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/disk+lightglue-official.yaml rename to third_party/gim/gim/gluefactory/configs/disk+lightglue-official.yaml diff --git a/third_party/gim/gluefactory/configs/disk+lightglue_homography.yaml b/third_party/gim/gim/gluefactory/configs/disk+lightglue_homography.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/disk+lightglue_homography.yaml rename to third_party/gim/gim/gluefactory/configs/disk+lightglue_homography.yaml diff --git a/third_party/gim/gluefactory/configs/disk+lightglue_megadepth.yaml b/third_party/gim/gim/gluefactory/configs/disk+lightglue_megadepth.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/disk+lightglue_megadepth.yaml rename to third_party/gim/gim/gluefactory/configs/disk+lightglue_megadepth.yaml diff --git a/third_party/gim/gluefactory/configs/sift+NN.yaml b/third_party/gim/gim/gluefactory/configs/sift+NN.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/sift+NN.yaml rename to third_party/gim/gim/gluefactory/configs/sift+NN.yaml diff --git a/third_party/gim/gluefactory/configs/sift+lightglue-official.yaml b/third_party/gim/gim/gluefactory/configs/sift+lightglue-official.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/sift+lightglue-official.yaml rename to third_party/gim/gim/gluefactory/configs/sift+lightglue-official.yaml diff --git a/third_party/gim/gluefactory/configs/sift+lightglue_homography.yaml b/third_party/gim/gim/gluefactory/configs/sift+lightglue_homography.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/sift+lightglue_homography.yaml rename to third_party/gim/gim/gluefactory/configs/sift+lightglue_homography.yaml diff --git a/third_party/gim/gluefactory/configs/sift+lightglue_megadepth.yaml b/third_party/gim/gim/gluefactory/configs/sift+lightglue_megadepth.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/sift+lightglue_megadepth.yaml rename to third_party/gim/gim/gluefactory/configs/sift+lightglue_megadepth.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+NN.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+NN.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+NN.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint+NN.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+lightglue-official.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+lightglue-official.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+lightglue-official.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint+lightglue-official.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+lightglue_homography.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+lightglue_homography.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+lightglue_homography.yaml rename to 
third_party/gim/gim/gluefactory/configs/superpoint+lightglue_homography.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+lightglue_megadepth.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+lightglue_megadepth.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+lightglue_megadepth.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint+lightglue_megadepth.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint+lsd+gluestick-homography.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint+lsd+gluestick-megadepth.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+lsd+gluestick.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+lsd+gluestick.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+lsd+gluestick.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint+lsd+gluestick.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint+superglue-official.yaml b/third_party/gim/gim/gluefactory/configs/superpoint+superglue-official.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint+superglue-official.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint+superglue-official.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint-open+NN.yaml b/third_party/gim/gim/gluefactory/configs/superpoint-open+NN.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint-open+NN.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint-open+NN.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint-open+lightglue_homography.yaml b/third_party/gim/gim/gluefactory/configs/superpoint-open+lightglue_homography.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint-open+lightglue_homography.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint-open+lightglue_homography.yaml diff --git a/third_party/gim/gluefactory/configs/superpoint-open+lightglue_megadepth.yaml b/third_party/gim/gim/gluefactory/configs/superpoint-open+lightglue_megadepth.yaml similarity index 100% rename from third_party/gim/gluefactory/configs/superpoint-open+lightglue_megadepth.yaml rename to third_party/gim/gim/gluefactory/configs/superpoint-open+lightglue_megadepth.yaml diff --git a/third_party/gim/gluefactory/datasets/__init__.py b/third_party/gim/gim/gluefactory/datasets/__init__.py similarity index 100% rename from third_party/gim/gluefactory/datasets/__init__.py rename to third_party/gim/gim/gluefactory/datasets/__init__.py diff --git a/third_party/gim/gluefactory/datasets/augmentations.py b/third_party/gim/gim/gluefactory/datasets/augmentations.py similarity index 100% rename from third_party/gim/gluefactory/datasets/augmentations.py rename to third_party/gim/gim/gluefactory/datasets/augmentations.py diff --git a/third_party/gim/gluefactory/datasets/base_dataset.py 
b/third_party/gim/gim/gluefactory/datasets/base_dataset.py similarity index 100% rename from third_party/gim/gluefactory/datasets/base_dataset.py rename to third_party/gim/gim/gluefactory/datasets/base_dataset.py diff --git a/third_party/gim/gluefactory/datasets/eth3d.py b/third_party/gim/gim/gluefactory/datasets/eth3d.py similarity index 100% rename from third_party/gim/gluefactory/datasets/eth3d.py rename to third_party/gim/gim/gluefactory/datasets/eth3d.py diff --git a/third_party/gim/gluefactory/datasets/homographies.py b/third_party/gim/gim/gluefactory/datasets/homographies.py similarity index 100% rename from third_party/gim/gluefactory/datasets/homographies.py rename to third_party/gim/gim/gluefactory/datasets/homographies.py diff --git a/third_party/gim/gluefactory/datasets/hpatches.py b/third_party/gim/gim/gluefactory/datasets/hpatches.py similarity index 100% rename from third_party/gim/gluefactory/datasets/hpatches.py rename to third_party/gim/gim/gluefactory/datasets/hpatches.py diff --git a/third_party/gim/gluefactory/datasets/image_folder.py b/third_party/gim/gim/gluefactory/datasets/image_folder.py similarity index 100% rename from third_party/gim/gluefactory/datasets/image_folder.py rename to third_party/gim/gim/gluefactory/datasets/image_folder.py diff --git a/third_party/gim/gluefactory/datasets/image_pairs.py b/third_party/gim/gim/gluefactory/datasets/image_pairs.py similarity index 100% rename from third_party/gim/gluefactory/datasets/image_pairs.py rename to third_party/gim/gim/gluefactory/datasets/image_pairs.py diff --git a/third_party/gim/gluefactory/datasets/megadepth.py b/third_party/gim/gim/gluefactory/datasets/megadepth.py similarity index 100% rename from third_party/gim/gluefactory/datasets/megadepth.py rename to third_party/gim/gim/gluefactory/datasets/megadepth.py diff --git a/third_party/gim/gluefactory/datasets/utils.py b/third_party/gim/gim/gluefactory/datasets/utils.py similarity index 100% rename from third_party/gim/gluefactory/datasets/utils.py rename to third_party/gim/gim/gluefactory/datasets/utils.py diff --git a/third_party/gim/gluefactory/eval/__init__.py b/third_party/gim/gim/gluefactory/eval/__init__.py similarity index 100% rename from third_party/gim/gluefactory/eval/__init__.py rename to third_party/gim/gim/gluefactory/eval/__init__.py diff --git a/third_party/gim/gluefactory/eval/eth3d.py b/third_party/gim/gim/gluefactory/eval/eth3d.py similarity index 100% rename from third_party/gim/gluefactory/eval/eth3d.py rename to third_party/gim/gim/gluefactory/eval/eth3d.py diff --git a/third_party/gim/gluefactory/eval/eval_pipeline.py b/third_party/gim/gim/gluefactory/eval/eval_pipeline.py similarity index 100% rename from third_party/gim/gluefactory/eval/eval_pipeline.py rename to third_party/gim/gim/gluefactory/eval/eval_pipeline.py diff --git a/third_party/gim/gluefactory/eval/hpatches.py b/third_party/gim/gim/gluefactory/eval/hpatches.py similarity index 100% rename from third_party/gim/gluefactory/eval/hpatches.py rename to third_party/gim/gim/gluefactory/eval/hpatches.py diff --git a/third_party/gim/gluefactory/eval/inspect.py b/third_party/gim/gim/gluefactory/eval/inspect.py similarity index 100% rename from third_party/gim/gluefactory/eval/inspect.py rename to third_party/gim/gim/gluefactory/eval/inspect.py diff --git a/third_party/gim/gluefactory/eval/io.py b/third_party/gim/gim/gluefactory/eval/io.py similarity index 100% rename from third_party/gim/gluefactory/eval/io.py rename to third_party/gim/gim/gluefactory/eval/io.py diff --git 
a/third_party/gim/gluefactory/eval/megadepth1500.py b/third_party/gim/gim/gluefactory/eval/megadepth1500.py similarity index 100% rename from third_party/gim/gluefactory/eval/megadepth1500.py rename to third_party/gim/gim/gluefactory/eval/megadepth1500.py diff --git a/third_party/gim/gluefactory/eval/utils.py b/third_party/gim/gim/gluefactory/eval/utils.py similarity index 100% rename from third_party/gim/gluefactory/eval/utils.py rename to third_party/gim/gim/gluefactory/eval/utils.py diff --git a/third_party/gim/gluefactory/geometry/depth.py b/third_party/gim/gim/gluefactory/geometry/depth.py similarity index 100% rename from third_party/gim/gluefactory/geometry/depth.py rename to third_party/gim/gim/gluefactory/geometry/depth.py diff --git a/third_party/gim/gluefactory/geometry/epipolar.py b/third_party/gim/gim/gluefactory/geometry/epipolar.py similarity index 100% rename from third_party/gim/gluefactory/geometry/epipolar.py rename to third_party/gim/gim/gluefactory/geometry/epipolar.py diff --git a/third_party/gim/gluefactory/geometry/gt_generation.py b/third_party/gim/gim/gluefactory/geometry/gt_generation.py similarity index 100% rename from third_party/gim/gluefactory/geometry/gt_generation.py rename to third_party/gim/gim/gluefactory/geometry/gt_generation.py diff --git a/third_party/gim/gluefactory/geometry/homography.py b/third_party/gim/gim/gluefactory/geometry/homography.py similarity index 100% rename from third_party/gim/gluefactory/geometry/homography.py rename to third_party/gim/gim/gluefactory/geometry/homography.py diff --git a/third_party/gim/gluefactory/geometry/utils.py b/third_party/gim/gim/gluefactory/geometry/utils.py similarity index 100% rename from third_party/gim/gluefactory/geometry/utils.py rename to third_party/gim/gim/gluefactory/geometry/utils.py diff --git a/third_party/gim/gluefactory/geometry/wrappers.py b/third_party/gim/gim/gluefactory/geometry/wrappers.py similarity index 100% rename from third_party/gim/gluefactory/geometry/wrappers.py rename to third_party/gim/gim/gluefactory/geometry/wrappers.py diff --git a/third_party/gim/gluefactory/models/__init__.py b/third_party/gim/gim/gluefactory/models/__init__.py similarity index 100% rename from third_party/gim/gluefactory/models/__init__.py rename to third_party/gim/gim/gluefactory/models/__init__.py diff --git a/third_party/gim/gluefactory/models/backbones/__init__.py b/third_party/gim/gim/gluefactory/models/backbones/__init__.py similarity index 100% rename from third_party/gim/gluefactory/models/backbones/__init__.py rename to third_party/gim/gim/gluefactory/models/backbones/__init__.py diff --git a/third_party/gim/gluefactory/models/backbones/dinov2.py b/third_party/gim/gim/gluefactory/models/backbones/dinov2.py similarity index 100% rename from third_party/gim/gluefactory/models/backbones/dinov2.py rename to third_party/gim/gim/gluefactory/models/backbones/dinov2.py diff --git a/third_party/gim/gluefactory/models/base_model.py b/third_party/gim/gim/gluefactory/models/base_model.py similarity index 100% rename from third_party/gim/gluefactory/models/base_model.py rename to third_party/gim/gim/gluefactory/models/base_model.py diff --git a/third_party/gim/gluefactory/models/cache_loader.py b/third_party/gim/gim/gluefactory/models/cache_loader.py similarity index 100% rename from third_party/gim/gluefactory/models/cache_loader.py rename to third_party/gim/gim/gluefactory/models/cache_loader.py diff --git a/third_party/gim/gluefactory/models/extractors/__init__.py 
b/third_party/gim/gim/gluefactory/models/extractors/__init__.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/__init__.py rename to third_party/gim/gim/gluefactory/models/extractors/__init__.py diff --git a/third_party/gim/gluefactory/models/extractors/aliked.py b/third_party/gim/gim/gluefactory/models/extractors/aliked.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/aliked.py rename to third_party/gim/gim/gluefactory/models/extractors/aliked.py diff --git a/third_party/gim/gluefactory/models/extractors/disk_kornia.py b/third_party/gim/gim/gluefactory/models/extractors/disk_kornia.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/disk_kornia.py rename to third_party/gim/gim/gluefactory/models/extractors/disk_kornia.py diff --git a/third_party/gim/gluefactory/models/extractors/grid_extractor.py b/third_party/gim/gim/gluefactory/models/extractors/grid_extractor.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/grid_extractor.py rename to third_party/gim/gim/gluefactory/models/extractors/grid_extractor.py diff --git a/third_party/gim/gluefactory/models/extractors/keynet_affnet_hardnet.py b/third_party/gim/gim/gluefactory/models/extractors/keynet_affnet_hardnet.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/keynet_affnet_hardnet.py rename to third_party/gim/gim/gluefactory/models/extractors/keynet_affnet_hardnet.py diff --git a/third_party/gim/gluefactory/models/extractors/mixed.py b/third_party/gim/gim/gluefactory/models/extractors/mixed.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/mixed.py rename to third_party/gim/gim/gluefactory/models/extractors/mixed.py diff --git a/third_party/gim/gluefactory/models/extractors/sift.py b/third_party/gim/gim/gluefactory/models/extractors/sift.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/sift.py rename to third_party/gim/gim/gluefactory/models/extractors/sift.py diff --git a/third_party/gim/gluefactory/models/extractors/sift_kornia.py b/third_party/gim/gim/gluefactory/models/extractors/sift_kornia.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/sift_kornia.py rename to third_party/gim/gim/gluefactory/models/extractors/sift_kornia.py diff --git a/third_party/gim/gluefactory/models/extractors/superpoint_open.py b/third_party/gim/gim/gluefactory/models/extractors/superpoint_open.py similarity index 100% rename from third_party/gim/gluefactory/models/extractors/superpoint_open.py rename to third_party/gim/gim/gluefactory/models/extractors/superpoint_open.py diff --git a/third_party/gim/gluefactory/models/lines/__init__.py b/third_party/gim/gim/gluefactory/models/lines/__init__.py similarity index 100% rename from third_party/gim/gluefactory/models/lines/__init__.py rename to third_party/gim/gim/gluefactory/models/lines/__init__.py diff --git a/third_party/gim/gluefactory/models/lines/deeplsd.py b/third_party/gim/gim/gluefactory/models/lines/deeplsd.py similarity index 100% rename from third_party/gim/gluefactory/models/lines/deeplsd.py rename to third_party/gim/gim/gluefactory/models/lines/deeplsd.py diff --git a/third_party/gim/gluefactory/models/lines/lsd.py b/third_party/gim/gim/gluefactory/models/lines/lsd.py similarity index 100% rename from third_party/gim/gluefactory/models/lines/lsd.py rename to third_party/gim/gim/gluefactory/models/lines/lsd.py diff --git 
a/third_party/gim/gluefactory/models/lines/wireframe.py b/third_party/gim/gim/gluefactory/models/lines/wireframe.py similarity index 100% rename from third_party/gim/gluefactory/models/lines/wireframe.py rename to third_party/gim/gim/gluefactory/models/lines/wireframe.py diff --git a/third_party/gim/gluefactory/models/matchers/__init__.py b/third_party/gim/gim/gluefactory/models/matchers/__init__.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/__init__.py rename to third_party/gim/gim/gluefactory/models/matchers/__init__.py diff --git a/third_party/gim/gluefactory/models/matchers/adalam.py b/third_party/gim/gim/gluefactory/models/matchers/adalam.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/adalam.py rename to third_party/gim/gim/gluefactory/models/matchers/adalam.py diff --git a/third_party/gim/gluefactory/models/matchers/depth_matcher.py b/third_party/gim/gim/gluefactory/models/matchers/depth_matcher.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/depth_matcher.py rename to third_party/gim/gim/gluefactory/models/matchers/depth_matcher.py diff --git a/third_party/gim/gluefactory/models/matchers/gluestick.py b/third_party/gim/gim/gluefactory/models/matchers/gluestick.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/gluestick.py rename to third_party/gim/gim/gluefactory/models/matchers/gluestick.py diff --git a/third_party/gim/gluefactory/models/matchers/homography_matcher.py b/third_party/gim/gim/gluefactory/models/matchers/homography_matcher.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/homography_matcher.py rename to third_party/gim/gim/gluefactory/models/matchers/homography_matcher.py diff --git a/third_party/gim/gluefactory/models/matchers/kornia_loftr.py b/third_party/gim/gim/gluefactory/models/matchers/kornia_loftr.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/kornia_loftr.py rename to third_party/gim/gim/gluefactory/models/matchers/kornia_loftr.py diff --git a/third_party/gim/gluefactory/models/matchers/lightglue.py b/third_party/gim/gim/gluefactory/models/matchers/lightglue.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/lightglue.py rename to third_party/gim/gim/gluefactory/models/matchers/lightglue.py diff --git a/third_party/gim/gluefactory/models/matchers/lightglue_pretrained.py b/third_party/gim/gim/gluefactory/models/matchers/lightglue_pretrained.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/lightglue_pretrained.py rename to third_party/gim/gim/gluefactory/models/matchers/lightglue_pretrained.py diff --git a/third_party/gim/gluefactory/models/matchers/nearest_neighbor_matcher.py b/third_party/gim/gim/gluefactory/models/matchers/nearest_neighbor_matcher.py similarity index 100% rename from third_party/gim/gluefactory/models/matchers/nearest_neighbor_matcher.py rename to third_party/gim/gim/gluefactory/models/matchers/nearest_neighbor_matcher.py diff --git a/third_party/gim/gluefactory/models/triplet_pipeline.py b/third_party/gim/gim/gluefactory/models/triplet_pipeline.py similarity index 100% rename from third_party/gim/gluefactory/models/triplet_pipeline.py rename to third_party/gim/gim/gluefactory/models/triplet_pipeline.py diff --git a/third_party/gim/gluefactory/models/two_view_pipeline.py b/third_party/gim/gim/gluefactory/models/two_view_pipeline.py similarity index 100% rename from 
third_party/gim/gluefactory/models/two_view_pipeline.py rename to third_party/gim/gim/gluefactory/models/two_view_pipeline.py diff --git a/third_party/gim/gluefactory/models/utils/__init__.py b/third_party/gim/gim/gluefactory/models/utils/__init__.py similarity index 100% rename from third_party/gim/gluefactory/models/utils/__init__.py rename to third_party/gim/gim/gluefactory/models/utils/__init__.py diff --git a/third_party/gim/gluefactory/models/utils/losses.py b/third_party/gim/gim/gluefactory/models/utils/losses.py similarity index 100% rename from third_party/gim/gluefactory/models/utils/losses.py rename to third_party/gim/gim/gluefactory/models/utils/losses.py diff --git a/third_party/gim/gluefactory/models/utils/metrics.py b/third_party/gim/gim/gluefactory/models/utils/metrics.py similarity index 100% rename from third_party/gim/gluefactory/models/utils/metrics.py rename to third_party/gim/gim/gluefactory/models/utils/metrics.py diff --git a/third_party/gim/gluefactory/models/utils/misc.py b/third_party/gim/gim/gluefactory/models/utils/misc.py similarity index 100% rename from third_party/gim/gluefactory/models/utils/misc.py rename to third_party/gim/gim/gluefactory/models/utils/misc.py diff --git a/third_party/gim/gluefactory/robust_estimators/__init__.py b/third_party/gim/gim/gluefactory/robust_estimators/__init__.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/__init__.py rename to third_party/gim/gim/gluefactory/robust_estimators/__init__.py diff --git a/third_party/gim/gluefactory/robust_estimators/base_estimator.py b/third_party/gim/gim/gluefactory/robust_estimators/base_estimator.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/base_estimator.py rename to third_party/gim/gim/gluefactory/robust_estimators/base_estimator.py diff --git a/third_party/gim/gluefactory/robust_estimators/homography/__init__.py b/third_party/gim/gim/gluefactory/robust_estimators/homography/__init__.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/homography/__init__.py rename to third_party/gim/gim/gluefactory/robust_estimators/homography/__init__.py diff --git a/third_party/gim/gluefactory/robust_estimators/homography/homography_est.py b/third_party/gim/gim/gluefactory/robust_estimators/homography/homography_est.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/homography/homography_est.py rename to third_party/gim/gim/gluefactory/robust_estimators/homography/homography_est.py diff --git a/third_party/gim/gluefactory/robust_estimators/homography/opencv.py b/third_party/gim/gim/gluefactory/robust_estimators/homography/opencv.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/homography/opencv.py rename to third_party/gim/gim/gluefactory/robust_estimators/homography/opencv.py diff --git a/third_party/gim/gluefactory/robust_estimators/homography/poselib.py b/third_party/gim/gim/gluefactory/robust_estimators/homography/poselib.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/homography/poselib.py rename to third_party/gim/gim/gluefactory/robust_estimators/homography/poselib.py diff --git a/third_party/gim/gluefactory/robust_estimators/relative_pose/__init__.py b/third_party/gim/gim/gluefactory/robust_estimators/relative_pose/__init__.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/relative_pose/__init__.py rename to 
third_party/gim/gim/gluefactory/robust_estimators/relative_pose/__init__.py diff --git a/third_party/gim/gluefactory/robust_estimators/relative_pose/opencv.py b/third_party/gim/gim/gluefactory/robust_estimators/relative_pose/opencv.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/relative_pose/opencv.py rename to third_party/gim/gim/gluefactory/robust_estimators/relative_pose/opencv.py diff --git a/third_party/gim/gluefactory/robust_estimators/relative_pose/poselib.py b/third_party/gim/gim/gluefactory/robust_estimators/relative_pose/poselib.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/relative_pose/poselib.py rename to third_party/gim/gim/gluefactory/robust_estimators/relative_pose/poselib.py diff --git a/third_party/gim/gluefactory/robust_estimators/relative_pose/pycolmap.py b/third_party/gim/gim/gluefactory/robust_estimators/relative_pose/pycolmap.py similarity index 100% rename from third_party/gim/gluefactory/robust_estimators/relative_pose/pycolmap.py rename to third_party/gim/gim/gluefactory/robust_estimators/relative_pose/pycolmap.py diff --git a/third_party/gim/gluefactory/scripts/__init__.py b/third_party/gim/gim/gluefactory/scripts/__init__.py similarity index 100% rename from third_party/gim/gluefactory/scripts/__init__.py rename to third_party/gim/gim/gluefactory/scripts/__init__.py diff --git a/third_party/gim/gluefactory/scripts/export_local_features.py b/third_party/gim/gim/gluefactory/scripts/export_local_features.py similarity index 100% rename from third_party/gim/gluefactory/scripts/export_local_features.py rename to third_party/gim/gim/gluefactory/scripts/export_local_features.py diff --git a/third_party/gim/gluefactory/scripts/export_megadepth.py b/third_party/gim/gim/gluefactory/scripts/export_megadepth.py similarity index 100% rename from third_party/gim/gluefactory/scripts/export_megadepth.py rename to third_party/gim/gim/gluefactory/scripts/export_megadepth.py diff --git a/third_party/gim/gluefactory/settings.py b/third_party/gim/gim/gluefactory/settings.py similarity index 100% rename from third_party/gim/gluefactory/settings.py rename to third_party/gim/gim/gluefactory/settings.py diff --git a/third_party/gim/gluefactory/superpoint.py b/third_party/gim/gim/gluefactory/superpoint.py similarity index 100% rename from third_party/gim/gluefactory/superpoint.py rename to third_party/gim/gim/gluefactory/superpoint.py diff --git a/third_party/gim/gluefactory/train.py b/third_party/gim/gim/gluefactory/train.py similarity index 100% rename from third_party/gim/gluefactory/train.py rename to third_party/gim/gim/gluefactory/train.py diff --git a/third_party/gim/gluefactory/utils/__init__.py b/third_party/gim/gim/gluefactory/utils/__init__.py similarity index 100% rename from third_party/gim/gluefactory/utils/__init__.py rename to third_party/gim/gim/gluefactory/utils/__init__.py diff --git a/third_party/gim/gluefactory/utils/benchmark.py b/third_party/gim/gim/gluefactory/utils/benchmark.py similarity index 100% rename from third_party/gim/gluefactory/utils/benchmark.py rename to third_party/gim/gim/gluefactory/utils/benchmark.py diff --git a/third_party/gim/gluefactory/utils/export_predictions.py b/third_party/gim/gim/gluefactory/utils/export_predictions.py similarity index 100% rename from third_party/gim/gluefactory/utils/export_predictions.py rename to third_party/gim/gim/gluefactory/utils/export_predictions.py diff --git a/third_party/gim/gluefactory/utils/image.py 
b/third_party/gim/gim/gluefactory/utils/image.py similarity index 100% rename from third_party/gim/gluefactory/utils/image.py rename to third_party/gim/gim/gluefactory/utils/image.py diff --git a/third_party/gim/gluefactory/utils/misc.py b/third_party/gim/gim/gluefactory/utils/misc.py similarity index 100% rename from third_party/gim/gluefactory/utils/misc.py rename to third_party/gim/gim/gluefactory/utils/misc.py diff --git a/third_party/gim/gluefactory/utils/patches.py b/third_party/gim/gim/gluefactory/utils/patches.py similarity index 100% rename from third_party/gim/gluefactory/utils/patches.py rename to third_party/gim/gim/gluefactory/utils/patches.py diff --git a/third_party/gim/gluefactory/utils/stdout_capturing.py b/third_party/gim/gim/gluefactory/utils/stdout_capturing.py similarity index 100% rename from third_party/gim/gluefactory/utils/stdout_capturing.py rename to third_party/gim/gim/gluefactory/utils/stdout_capturing.py diff --git a/third_party/gim/gluefactory/utils/tensor.py b/third_party/gim/gim/gluefactory/utils/tensor.py similarity index 100% rename from third_party/gim/gluefactory/utils/tensor.py rename to third_party/gim/gim/gluefactory/utils/tensor.py diff --git a/third_party/gim/gluefactory/utils/tools.py b/third_party/gim/gim/gluefactory/utils/tools.py similarity index 100% rename from third_party/gim/gluefactory/utils/tools.py rename to third_party/gim/gim/gluefactory/utils/tools.py diff --git a/third_party/gim/gluefactory/visualization/global_frame.py b/third_party/gim/gim/gluefactory/visualization/global_frame.py similarity index 100% rename from third_party/gim/gluefactory/visualization/global_frame.py rename to third_party/gim/gim/gluefactory/visualization/global_frame.py diff --git a/third_party/gim/gluefactory/visualization/tools.py b/third_party/gim/gim/gluefactory/visualization/tools.py similarity index 100% rename from third_party/gim/gluefactory/visualization/tools.py rename to third_party/gim/gim/gluefactory/visualization/tools.py diff --git a/third_party/gim/gluefactory/visualization/two_view_frame.py b/third_party/gim/gim/gluefactory/visualization/two_view_frame.py similarity index 100% rename from third_party/gim/gluefactory/visualization/two_view_frame.py rename to third_party/gim/gim/gluefactory/visualization/two_view_frame.py diff --git a/third_party/gim/gluefactory/visualization/visualize_batch.py b/third_party/gim/gim/gluefactory/visualization/visualize_batch.py similarity index 100% rename from third_party/gim/gluefactory/visualization/visualize_batch.py rename to third_party/gim/gim/gluefactory/visualization/visualize_batch.py diff --git a/third_party/gim/gluefactory/visualization/viz2d.py b/third_party/gim/gim/gluefactory/visualization/viz2d.py similarity index 100% rename from third_party/gim/gluefactory/visualization/viz2d.py rename to third_party/gim/gim/gluefactory/visualization/viz2d.py diff --git a/third_party/gim/gim/lightglue/__init__.py b/third_party/gim/gim/lightglue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd8fb3b1e3b54f80fdf70688fb4e4705305a723 --- /dev/null +++ b/third_party/gim/gim/lightglue/__init__.py @@ -0,0 +1,17 @@ +import logging + +# from .utils.experiments import load_experiment # noqa: F401 + +formatter = logging.Formatter( + fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S" +) +handler = logging.StreamHandler() +handler.setFormatter(formatter) +handler.setLevel(logging.INFO) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) 
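+# attach the configured handler and disable propagation so this package's
+# log records are not emitted a second time by the root logger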
+logger.addHandler(handler) +logger.propagate = False + +__module_name__ = __name__ diff --git a/third_party/gim/gim/lightglue/models/__init__.py b/third_party/gim/gim/lightglue/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d1a05c66bbc22a711cb968be00985a31a3dfd5 --- /dev/null +++ b/third_party/gim/gim/lightglue/models/__init__.py @@ -0,0 +1,30 @@ +import importlib.util + +from ..utils.tools import get_class +from .base_model import BaseModel + + +def get_model(name): + import_paths = [ + name, + f"{__name__}.{name}", + f"{__name__}.extractors.{name}", # backward compatibility + f"{__name__}.matchers.{name}", # backward compatibility + ] + for path in import_paths: + try: + spec = importlib.util.find_spec(path) + except ModuleNotFoundError: + spec = None + if spec is not None: + try: + return get_class(path, BaseModel) + except AssertionError: + mod = __import__(path, fromlist=[""]) + try: + return mod.__main_model__ + except AttributeError as exc: + print(exc) + continue + + raise RuntimeError(f'Model {name} not found in any of [{" ".join(import_paths)}]') diff --git a/third_party/gim/gim/lightglue/models/base_model.py b/third_party/gim/gim/lightglue/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f66288b9f724468c4409171b9c374c794ae9c9 --- /dev/null +++ b/third_party/gim/gim/lightglue/models/base_model.py @@ -0,0 +1,157 @@ +""" +Base class for trainable models. +""" + +from abc import ABCMeta, abstractmethod +from copy import copy + +import omegaconf +from omegaconf import OmegaConf +from torch import nn + + +class MetaModel(ABCMeta): + def __prepare__(name, bases, **kwds): + total_conf = OmegaConf.create() + for base in bases: + for key in ("base_default_conf", "default_conf"): + update = getattr(base, key, {}) + if isinstance(update, dict): + update = OmegaConf.create(update) + total_conf = OmegaConf.merge(total_conf, update) + return dict(base_default_conf=total_conf) + + +class BaseModel(nn.Module, metaclass=MetaModel): + """ + What the child model is expected to declare: + default_conf: dictionary of the default configuration of the model. + It recursively updates the default_conf of all parent classes, and + it is updated by the user-provided configuration passed to __init__. + Configurations can be nested. + + required_data_keys: list of expected keys in the input data dictionary. + + strict_conf (optional): boolean. If false, BaseModel does not raise + an error when the user provides an unknown configuration entry. + + _init(self, conf): initialization method, where conf is the final + configuration object (also accessible with `self.conf`). Accessing + unknown configuration entries will raise an error. + + _forward(self, data): method that returns a dictionary of batched + prediction tensors based on a dictionary of batched input data tensors. + + loss(self, pred, data): method that returns a dictionary of losses, + computed from model predictions and input data. Each loss is a batch + of scalars, i.e. a torch.Tensor of shape (B,). + The total loss to be optimized has the key `'total'`. + + metrics(self, pred, data): method that returns a dictionary of metrics, + each as a batch of scalars.
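+
+    A minimal conforming subclass, for illustration only (the class name
+    and conf key below are hypothetical, not part of this patch):
+
+        class ConstantScorer(BaseModel):
+            default_conf = {"score": 1.0}
+            required_data_keys = ["image0", "image1"]
+
+            def _init(self, conf):
+                self.score = conf.score
+
+            def _forward(self, data):
+                return {"scores": self.score}
+
+            def loss(self, pred, data):
+                raise NotImplementedError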
+ """ + + default_conf = { + "name": None, + "trainable": True, # if false: do not optimize this model parameters + "freeze_batch_normalization": False, # use test-time statistics + "timeit": False, # time forward pass + } + required_data_keys = [] + strict_conf = False + + are_weights_initialized = False + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + default_conf = OmegaConf.merge( + self.base_default_conf, OmegaConf.create(self.default_conf) + ) + if self.strict_conf: + OmegaConf.set_struct(default_conf, True) + + # fixme: backward compatibility + if "pad" in conf and "pad" not in default_conf: # backward compat. + with omegaconf.read_write(conf): + with omegaconf.open_dict(conf): + conf["interpolation"] = {"pad": conf.pop("pad")} + + if isinstance(conf, dict): + conf = OmegaConf.create(conf) + self.conf = conf = OmegaConf.merge(default_conf, conf) + OmegaConf.set_readonly(conf, True) + OmegaConf.set_struct(conf, True) + self.required_data_keys = copy(self.required_data_keys) + self._init(conf) + + if not conf.trainable: + for p in self.parameters(): + p.requires_grad = False + + def train(self, mode=True): + super().train(mode) + + def freeze_bn(module): + if isinstance(module, nn.modules.batchnorm._BatchNorm): + module.eval() + + if self.conf.freeze_batch_normalization: + self.apply(freeze_bn) + + return self + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + + def recursive_key_check(expected, given): + for key in expected: + assert key in given, f"Missing key {key} in data" + if isinstance(expected, dict): + recursive_key_check(expected[key], given[key]) + + recursive_key_check(self.required_data_keys, data) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def loss(self, pred, data): + """To be implemented by the child class.""" + raise NotImplementedError + + def load_state_dict(self, *args, **kwargs): + """Load the state dict of the model, and set the model to initialized.""" + ret = super().load_state_dict(*args, **kwargs) + self.set_initialized() + return ret + + def is_initialized(self): + """Recursively check if the model is initialized, i.e. 
weights are loaded""" + is_initialized = True # initialize to true and perform recursive and + for _, w in self.named_children(): + if isinstance(w, BaseModel): + # if children is BaseModel, we perform recursive check + is_initialized = is_initialized and w.is_initialized() + else: + # else, we check if self is initialized or the children has no params + n_params = len(list(w.parameters())) + is_initialized = is_initialized and ( + n_params == 0 or self.are_weights_initialized + ) + return is_initialized + + def set_initialized(self, to: bool = True): + """Recursively set the initialization state.""" + self.are_weights_initialized = to + for _, w in self.named_parameters(): + if isinstance(w, BaseModel): + w.set_initialized(to) diff --git a/third_party/gim/gim/lightglue/models/matchers/__init__.py b/third_party/gim/gim/lightglue/models/matchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/gim/gim/lightglue/models/matchers/lightglue.py b/third_party/gim/gim/lightglue/models/matchers/lightglue.py new file mode 100644 index 0000000000000000000000000000000000000000..3dfda2a2968e038ee2d90ecff0533af9d3a14484 --- /dev/null +++ b/third_party/gim/gim/lightglue/models/matchers/lightglue.py @@ -0,0 +1,632 @@ +import warnings +from pathlib import Path +from typing import Callable, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf +from torch import nn +from torch.utils.checkpoint import checkpoint + +# from ...settings import DATA_PATH +# from ..utils.losses import NLLLoss +# from ..utils.metrics import matcher_metrics + +FLASH_AVAILABLE = hasattr(F, "scaled_dot_product_attention") + +torch.backends.cudnn.deterministic = True + + +@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) +def normalize_keypoints( + kpts: torch.Tensor, size: Optional[torch.Tensor] = None +) -> torch.Tensor: + if size is None: + size = 1 + kpts.max(-2).values - kpts.min(-2).values + elif not isinstance(size, torch.Tensor): + size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype) + size = size.to(kpts) + shift = size / 2 + scale = size.max(-1).values / 2 + kpts = (kpts - shift[..., None, :]) / scale[..., None, None] + return kpts + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (t * freqs[0]) + (rotate_half(t) * freqs[1]) + + +class LearnableFourierPositionalEncoding(nn.Module): + def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None: + super().__init__() + F_dim = F_dim if F_dim is not None else dim + self.gamma = gamma + self.Wr = nn.Linear(M, F_dim // 2, bias=False) + nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """encode position vector""" + projected = self.Wr(x) + cosines, sines = torch.cos(projected), torch.sin(projected) + emb = torch.stack([cosines, sines], 0).unsqueeze(-3) + return emb.repeat_interleave(2, dim=-1) + + +class TokenConfidence(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid()) + self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """get confidence 
tokens""" + return ( + self.token(desc0.detach()).squeeze(-1), + self.token(desc1.detach()).squeeze(-1), + ) + + def loss(self, desc0, desc1, la_now, la_final): + logit0 = self.token[0](desc0.detach()).squeeze(-1) + logit1 = self.token[0](desc1.detach()).squeeze(-1) + la_now, la_final = la_now.detach(), la_final.detach() + correct0 = ( + la_final[:, :-1, :].max(-1).indices == la_now[:, :-1, :].max(-1).indices + ) + correct1 = ( + la_final[:, :, :-1].max(-2).indices == la_now[:, :, :-1].max(-2).indices + ) + return ( + self.loss_fn(logit0, correct0.float()).mean(-1) + + self.loss_fn(logit1, correct1.float()).mean(-1) + ) / 2.0 + + +class Attention(nn.Module): + def __init__(self, allow_flash: bool) -> None: + super().__init__() + if allow_flash and not FLASH_AVAILABLE: + warnings.warn( + "FlashAttention is not available. For optimal speed, " + "consider installing torch >= 2.0 or flash-attn.", + stacklevel=2, + ) + self.enable_flash = allow_flash and FLASH_AVAILABLE + + if FLASH_AVAILABLE: + torch.backends.cuda.enable_flash_sdp(allow_flash) + + def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.enable_flash and q.device.type == "cuda": + # use torch 2.0 scaled_dot_product_attention with flash + if FLASH_AVAILABLE: + args = [x.half().contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype) + return v if mask is None else v.nan_to_num() + elif FLASH_AVAILABLE: + args = [x.contiguous() for x in [q, k, v]] + v = F.scaled_dot_product_attention(*args, attn_mask=mask) + return v if mask is None else v.nan_to_num() + else: + s = q.shape[-1] ** -0.5 + sim = torch.einsum("...id,...jd->...ij", q, k) * s + if mask is not None: + sim.masked_fill(~mask, -float("inf")) + attn = F.softmax(sim, -1) + return torch.einsum("...ij,...jd->...id", attn, v) + + +class SelfBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0 + self.head_dim = self.embed_dim // num_heads + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) + self.inner_attn = Attention(flash) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + + def forward( + self, + x: torch.Tensor, + encoding: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv = self.Wqkv(x) + qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2) + q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2] + q = apply_cached_rotary_emb(encoding, q) + k = apply_cached_rotary_emb(encoding, k) + context = self.inner_attn(q, k, v, mask=mask) + message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2)) + return x + self.ffn(torch.cat([x, message], -1)) + + +class CrossBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True + ) -> None: + super().__init__() + self.heads = num_heads + dim_head = embed_dim // num_heads + self.scale = dim_head**-0.5 + inner_dim = dim_head * num_heads + self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias) + self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias) + self.ffn = nn.Sequential( + nn.Linear(2 * embed_dim, 2 * 
embed_dim), + nn.LayerNorm(2 * embed_dim, elementwise_affine=True), + nn.GELU(), + nn.Linear(2 * embed_dim, embed_dim), + ) + if flash and FLASH_AVAILABLE: + self.flash = Attention(True) + else: + self.flash = None + + def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor): + return func(x0), func(x1) + + def forward( + self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> List[torch.Tensor]: + qk0, qk1 = self.map_(self.to_qk, x0, x1) + v0, v1 = self.map_(self.to_v, x0, x1) + qk0, qk1, v0, v1 = map( + lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2), + (qk0, qk1, v0, v1), + ) + if self.flash is not None and qk0.device.type == "cuda": + m0 = self.flash(qk0, qk1, v1, mask) + m1 = self.flash( + qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None + ) + else: + qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5 + sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1) + if mask is not None: + sim = sim.masked_fill(~mask, -float("inf")) + attn01 = F.softmax(sim, dim=-1) + attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1) + m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1) + m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0) + if mask is not None: + m0, m1 = m0.nan_to_num(), m1.nan_to_num() + m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1) + m0, m1 = self.map_(self.to_out, m0, m1) + x0 = x0 + self.ffn(torch.cat([x0, m0], -1)) + x1 = x1 + self.ffn(torch.cat([x1, m1], -1)) + return x0, x1 + + +class TransformerLayer(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.self_attn = SelfBlock(*args, **kwargs) + self.cross_attn = CrossBlock(*args, **kwargs) + + def forward( + self, + desc0, + desc1, + encoding0, + encoding1, + mask0: Optional[torch.Tensor] = None, + mask1: Optional[torch.Tensor] = None, + ): + if mask0 is not None and mask1 is not None: + return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1) + else: + desc0 = self.self_attn(desc0, encoding0) + desc1 = self.self_attn(desc1, encoding1) + return self.cross_attn(desc0, desc1) + + # This part is compiled and allows padding inputs + def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1): + mask = mask0 & mask1.transpose(-1, -2) + mask0 = mask0 & mask0.transpose(-1, -2) + mask1 = mask1 & mask1.transpose(-1, -2) + desc0 = self.self_attn(desc0, encoding0, mask0) + desc1 = self.self_attn(desc1, encoding1, mask1) + return self.cross_attn(desc0, desc1, mask) + + +def sigmoid_log_double_softmax( + sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + b, m, n = sim.shape + certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2) + scores0 = F.log_softmax(sim, 2) + scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = sim.new_full((b, m + 1, n + 1), 0) + scores[:, :m, :n] = scores0 + scores1 + certainties + scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1)) + scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1)) + return scores + + +class MatchAssignment(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + self.matchability = nn.Linear(dim, 1, bias=True) + self.final_proj = nn.Linear(dim, dim, bias=True) + + def forward(self, desc0: torch.Tensor, desc1: torch.Tensor): + """build assignment matrix from descriptors""" + mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1) + 
_, _, d = mdesc0.shape + mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25 + sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1) + z0 = self.matchability(desc0) + z1 = self.matchability(desc1) + scores = sigmoid_log_double_softmax(sim, z0, z1) + return scores, sim + + def get_matchability(self, desc: torch.Tensor): + return torch.sigmoid(self.matchability(desc)).squeeze(-1) + + +def filter_matches(scores: torch.Tensor, th: float): + """obtain matches from a log assignment matrix [Bx M+1 x N+1]""" + max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) + m0, m1 = max0.indices, max1.indices + indices0 = torch.arange(m0.shape[1], device=m0.device)[None] + indices1 = torch.arange(m1.shape[1], device=m1.device)[None] + mutual0 = indices0 == m1.gather(1, m0) + mutual1 = indices1 == m0.gather(1, m1) + max0_exp = max0.values.exp() + zero = max0_exp.new_tensor(0) + mscores0 = torch.where(mutual0, max0_exp, zero) + mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero) + valid0 = mutual0 & (mscores0 > th) + valid1 = mutual1 & valid0.gather(1, m1) + m0 = torch.where(valid0, m0, -1) + m1 = torch.where(valid1, m1, -1) + return m0, m1, mscores0, mscores1 + + +class LightGlue(nn.Module): + default_conf = { + "name": "lightglue", # just for interfacing + "input_dim": 256, # input descriptor dimension (autoselected from weights) + "add_scale_ori": False, + "descriptor_dim": 256, + "n_layers": 9, + "num_heads": 4, + "flash": False, # enable FlashAttention if available. + "mp": False, # enable mixed precision + "depth_confidence": -1, # early stopping, disable with -1 + "width_confidence": -1, # point pruning, disable with -1 + "filter_threshold": 0.0, # match threshold + "checkpointed": False, + "weights": "superpoint_lightglue", # either a path or the name of pretrained weights (disk, ...) 
+ "weights_from_version": "v0.1_arxiv", + "loss": { + "gamma": 1.0, + "fn": "nll", + "nll_balancing": 0.5, + }, + } + + required_data_keys = ["keypoints0", "keypoints1", "descriptors0", "descriptors1"] + + url = "https://github.com/cvg/LightGlue/releases/download/{}/{}.pth" + + def __init__(self, conf) -> None: + super().__init__() + self.conf = conf = OmegaConf.merge(self.default_conf, conf) + if conf.input_dim != conf.descriptor_dim: + self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True) + else: + self.input_proj = nn.Identity() + + head_dim = conf.descriptor_dim // conf.num_heads + self.posenc = LearnableFourierPositionalEncoding( + 2 + 2 * conf.add_scale_ori, head_dim, head_dim + ) + + h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim + + self.transformers = nn.ModuleList( + [TransformerLayer(d, h, conf.flash) for _ in range(n)] + ) + + self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)]) + self.token_confidence = nn.ModuleList( + [TokenConfidence(d) for _ in range(n - 1)] + ) + + # self.loss_fn = NLLLoss(conf.loss) + + # state_dict = None + # if conf.weights is not None: + # # weights can be either a path or an existing file from official LG + # if Path(conf.weights).exists(): + # state_dict = torch.load(conf.weights, map_location="cpu") + # elif (Path(DATA_PATH) / conf.weights).exists(): + # state_dict = torch.load( + # str(DATA_PATH / conf.weights), map_location="cpu" + # ) + # elif (Path('weights') / (conf.weights + '.pth')).exists(): + # state_dict = torch.load( + # str(Path('weights') / (conf.weights + '.pth')), map_location="cpu" + # ) + # print(f"Readed weights from {Path('weights') / (conf.weights + '.pth')}") + # else: + # fname = ( + # f"{conf.weights}_{conf.weights_from_version}".replace(".", "-") + # + ".pth" + # ) + # state_dict = torch.hub.load_state_dict_from_url( + # self.url.format(conf.weights_from_version, conf.weights), + # file_name=fname, + # ) + # + # if state_dict: + # # rename old state dict entries + # for i in range(self.conf.n_layers): + # pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" + # state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + # pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" + # state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + # self.load_state_dict(state_dict, strict=False) + # print(f"Loaded weights from {conf.weights}") + + def compile(self, mode="reduce-overhead"): + if self.conf.width_confidence != -1: + warnings.warn( + "Point pruning is partially disabled for compiled forward.", + stacklevel=2, + ) + + for i in range(self.conf.n_layers): + self.transformers[i] = torch.compile( + self.transformers[i], mode=mode, fullgraph=True + ) + + def forward(self, data: dict) -> dict: + for key in self.required_data_keys: + assert key in data, f"Missing key {key} in data" + + kpts0, kpts1 = data["keypoints0"], data["keypoints1"] + b, m, _ = kpts0.shape + b, n, _ = kpts1.shape + device = kpts0.device + # if "view0" in data.keys() and "view1" in data.keys(): + size0 = data["resize0"][:, [1, 0]] + size1 = data["resize1"][:, [1, 0]] + kpts0 = normalize_keypoints(kpts0, size0).clone() + kpts1 = normalize_keypoints(kpts1, size1).clone() + + if self.conf.add_scale_ori: + sc0, o0 = data["scales0"], data["oris0"] + sc1, o1 = data["scales1"], data["oris1"] + kpts0 = torch.cat( + [ + kpts0, + sc0 if sc0.dim() == 3 else sc0[..., None], + o0 if o0.dim() == 3 else o0[..., None], + ], + -1, + ) + kpts1 = torch.cat( + [ + kpts1, + sc1 if 
sc1.dim() == 3 else sc1[..., None], + o1 if o1.dim() == 3 else o1[..., None], + ], + -1, + ) + + desc0 = data["descriptors0"].contiguous() + desc1 = data["descriptors1"].contiguous() + + assert desc0.shape[-1] == self.conf.input_dim + assert desc1.shape[-1] == self.conf.input_dim + if torch.is_autocast_enabled(): + desc0 = desc0.half() + desc1 = desc1.half() + desc0 = self.input_proj(desc0) + desc1 = self.input_proj(desc1) + # cache positional embeddings + encoding0 = self.posenc(kpts0) + encoding1 = self.posenc(kpts1) + + # GNN + final_proj + assignment + do_early_stop = self.conf.depth_confidence > 0 and not self.training + do_point_pruning = self.conf.width_confidence > 0 and not self.training + + all_desc0, all_desc1 = [], [] + + if do_point_pruning: + ind0 = torch.arange(0, m, device=device)[None] + ind1 = torch.arange(0, n, device=device)[None] + # We store the index of the layer at which pruning is detected. + prune0 = torch.ones_like(ind0) + prune1 = torch.ones_like(ind1) + token0, token1 = None, None + for i in range(self.conf.n_layers): + if self.conf.checkpointed and self.training: + desc0, desc1 = checkpoint( + self.transformers[i], desc0, desc1, encoding0, encoding1 + ) + else: + desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1) + if self.training or i == self.conf.n_layers - 1: + all_desc0.append(desc0) + all_desc1.append(desc1) + continue # no early stopping or adaptive width at last layer + + # only for eval + if do_early_stop: + assert b == 1 + token0, token1 = self.token_confidence[i](desc0, desc1) + if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n): + break + if do_point_pruning: + assert b == 1 + scores0 = self.log_assignment[i].get_matchability(desc0) + prunemask0 = self.get_pruning_mask(token0, scores0, i) + keep0 = torch.where(prunemask0)[1] + ind0 = ind0.index_select(1, keep0) + desc0 = desc0.index_select(1, keep0) + encoding0 = encoding0.index_select(-2, keep0) + prune0[:, ind0] += 1 + scores1 = self.log_assignment[i].get_matchability(desc1) + prunemask1 = self.get_pruning_mask(token1, scores1, i) + keep1 = torch.where(prunemask1)[1] + ind1 = ind1.index_select(1, keep1) + desc1 = desc1.index_select(1, keep1) + encoding1 = encoding1.index_select(-2, keep1) + prune1[:, ind1] += 1 + + desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] + scores, _ = self.log_assignment[i](desc0, desc1) + m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold) + matches, mscores = [], [] + for k in range(b): + if self.training: break + valid = m0[k] > -1 + m_indices_0 = torch.where(valid)[0] + m_indices_1 = m0[k][valid] + if do_point_pruning: + m_indices_0 = ind0[k, m_indices_0] + m_indices_1 = ind1[k, m_indices_1] + matches.append(torch.stack([m_indices_0, m_indices_1], -1)) + mscores.append(mscores0[k][valid]) + + if do_point_pruning: + m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype) + m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype) + m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0))) + m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0))) + mscores0_ = torch.zeros((b, m), device=mscores0.device) + mscores1_ = torch.zeros((b, n), device=mscores1.device) + mscores0_[:, ind0] = mscores0 + mscores1_[:, ind1] = mscores1 + m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_ + else: + prune0 = torch.ones_like(mscores0) * self.conf.n_layers + prune1 = torch.ones_like(mscores1) * self.conf.n_layers + + pred = { + "matches0": m0, + "matches1": 
m1, + "matching_scores0": mscores0, + "matching_scores1": mscores1, + "ref_descriptors0": torch.stack(all_desc0, 1), + "ref_descriptors1": torch.stack(all_desc1, 1), + "log_assignment": scores, + "stop": i + 1, + "matches": matches, + "scores": mscores, + "prune0": prune0, + "prune1": prune1, + } + + return pred + + def confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers) + return np.clip(threshold, 0, 1) + + def get_pruning_mask( + self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int + ) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.conf.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. + keep |= confidences <= self.confidence_thresholds[layer_index] + return keep + + def check_if_stop( + self, + confidences0: torch.Tensor, + confidences1: torch.Tensor, + layer_index: int, + num_points: int, + ) -> torch.Tensor: + """evaluate stopping condition""" + confidences = torch.cat([confidences0, confidences1], -1) + threshold = self.confidence_thresholds[layer_index] + ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points + return ratio_confident > self.conf.depth_confidence + + def pruning_min_kpts(self, device: torch.device): + if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda": + return self.pruning_keypoint_thresholds["flash"] + else: + return self.pruning_keypoint_thresholds[device.type] + + def loss(self, pred, data): + def loss_params(pred, i): + la, _ = self.log_assignment[i]( + pred["ref_descriptors0"][:, i], pred["ref_descriptors1"][:, i] + ) + return { + "log_assignment": la, + } + + sum_weights = 1.0 + nll, gt_weights, loss_metrics = self.loss_fn(loss_params(pred, -1), data) + N = pred["ref_descriptors0"].shape[1] + losses = {"total": nll, "last": nll.clone().detach(), **loss_metrics} + + if self.training: + losses["confidence"] = 0.0 + + # B = pred['log_assignment'].shape[0] + losses["row_norm"] = pred["log_assignment"].exp()[:, :-1].sum(2).mean(1) + for i in range(N - 1): + params_i = loss_params(pred, i) + nll, _, _ = self.loss_fn(params_i, data, weights=gt_weights) + + if self.conf.loss.gamma > 0.0: + weight = self.conf.loss.gamma ** (N - i - 1) + else: + weight = i + 1 + sum_weights += weight + losses["total"] = losses["total"] + nll * weight + + losses["confidence"] += self.token_confidence[i].loss( + pred["ref_descriptors0"][:, i], + pred["ref_descriptors1"][:, i], + params_i["log_assignment"], + pred["log_assignment"], + ) / (N - 1) + + del params_i + losses["total"] /= sum_weights + + # confidences + if self.training: + losses["total"] = losses["total"] + losses["confidence"] + + if not self.training: + # add metrics + metrics = matcher_metrics(pred, data) + else: + metrics = {} + return losses, metrics + + +__main_model__ = LightGlue diff --git a/third_party/gim/gim/lightglue/models/utils/__init__.py b/third_party/gim/gim/lightglue/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/gim/gim/lightglue/models/utils/misc.py b/third_party/gim/gim/lightglue/models/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e86d1add0e23a042963d878e484f0c582ff8b41c --- /dev/null +++ b/third_party/gim/gim/lightglue/models/utils/misc.py @@ -0,0 +1,70 @@ +import math +from typing import List, Optional, Tuple + 
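# ---------------------------------------------------------------------------
# Editor's note -- a minimal usage sketch for the vendored LightGlue above,
# not part of the diff. The import path and all tensor values are
# illustrative assumptions. Besides the four required_data_keys, forward()
# also reads data["resize0"] / data["resize1"] (stored as (h, w); the
# [:, [1, 0]] flip converts them to (w, h) before keypoint normalization),
# so both keys must be supplied as well.
#
#   import torch
#   from gim.lightglue.models.matchers.lightglue import LightGlue
#
#   matcher = LightGlue({"filter_threshold": 0.1, "flash": False}).eval()
#   data = {
#       "keypoints0": torch.rand(1, 512, 2) * 480,  # (B, M, 2) pixel coords
#       "keypoints1": torch.rand(1, 512, 2) * 480,  # (B, N, 2)
#       "descriptors0": torch.rand(1, 512, 256),    # (B, M, input_dim)
#       "descriptors1": torch.rand(1, 512, 256),    # (B, N, input_dim)
#       "resize0": torch.tensor([[480, 640]]),      # (B, 2) as (h, w)
#       "resize1": torch.tensor([[480, 640]]),
#   }
#   with torch.no_grad():
#       pred = matcher(data)  # pred["matches0"]: (B, M); -1 marks "no match"
# ---------------------------------------------------------------------------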
+import torch
+
+
+def to_sequence(map):
+    return map.flatten(-2).transpose(-1, -2)
+
+
+def to_map(sequence):
+    n = sequence.shape[-2]
+    e = math.isqrt(n)
+    assert e * e == n
+    # fix: the original repeated the assert above and dropped the return
+    return sequence.transpose(-1, -2).unflatten(-1, [e, e])
+
+
+def pad_to_length(
+    x,
+    length: int,
+    pad_dim: int = -2,
+    mode: str = "zeros",  # zeros, ones, random, random_c
+    bounds: Tuple[int] = (None, None),
+):
+    shape = list(x.shape)
+    d = x.shape[pad_dim]
+    assert d <= length
+    if d == length:
+        return x
+    shape[pad_dim] = length - d
+
+    low, high = bounds
+
+    if mode == "zeros":
+        xn = torch.zeros(*shape, device=x.device, dtype=x.dtype)
+    elif mode == "ones":
+        xn = torch.ones(*shape, device=x.device, dtype=x.dtype)
+    elif mode == "random":
+        low = low if low is not None else x.min()
+        high = high if high is not None else x.max()
+        xn = torch.empty(*shape, device=x.device).uniform_(low, high)
+    elif mode == "random_c":
+        low, high = bounds  # we use the bounds as fallback for empty seq.
+        xn = torch.cat(
+            [
+                torch.empty(*shape[:-1], 1, device=x.device).uniform_(
+                    x[..., i].min() if d > 0 else low,
+                    x[..., i].max() if d > 0 else high,
+                )
+                for i in range(shape[-1])
+            ],
+            dim=-1,
+        )
+    else:
+        raise ValueError(mode)
+    return torch.cat([x, xn], dim=pad_dim)
+
+
+def pad_and_stack(
+    sequences: List[torch.Tensor],
+    length: Optional[int] = None,
+    pad_dim: int = -2,
+    **kwargs,
+):
+    if length is None:
+        length = max([x.shape[pad_dim] for x in sequences])
+
+    y = torch.stack([pad_to_length(x, length, pad_dim, **kwargs) for x in sequences], 0)
+    return y
diff --git a/third_party/gim/gim/lightglue/superpoint.py b/third_party/gim/gim/lightglue/superpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..e68a47db5c168d6d51fdafd95c9a0e4225b6d70b
--- /dev/null
+++ b/third_party/gim/gim/lightglue/superpoint.py
@@ -0,0 +1,360 @@
+"""
+# %BANNER_BEGIN%
+# ---------------------------------------------------------------------
+# %COPYRIGHT_BEGIN%
+#
+# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
+#
+# Unpublished Copyright (c) 2020
+# Magic Leap, Inc., All Rights Reserved.
+#
+# NOTICE: All information contained herein is, and remains the property
+# of COMPANY. The intellectual and technical concepts contained herein
+# are proprietary to COMPANY and may be covered by U.S. and Foreign
+# Patents, patents in process, and are protected by trade secret or
+# copyright law. Dissemination of this information or reproduction of
+# this material is strictly forbidden unless prior written permission is
+# obtained from COMPANY. Access to the source code contained herein is
+# hereby forbidden to anyone except current COMPANY employees, managers
+# or contractors who have executed Confidentiality and Non-disclosure
+# agreements explicitly covering such access.
+#
+# The copyright notice above does not evidence any actual or intended
+# publication or disclosure of this source code, which includes
+# information that is confidential and/or proprietary, and is a trade
+# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
+# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
+# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
+# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
+# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
+# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
+# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
+# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
+#
+# %COPYRIGHT_END%
+# ----------------------------------------------------------------------
+# %AUTHORS_BEGIN%
+#
+# Originating Authors: Paul-Edouard Sarlin
+#
+# %AUTHORS_END%
+# --------------------------------------------------------------------*/
+# %BANNER_END%
+
+Described in:
+    SuperPoint: Self-Supervised Interest Point Detection and Description,
+    Daniel DeTone, Tomasz Malisiewicz, Andrew Rabinovich, CVPRW 2018.
+
+Original code: github.com/MagicLeapResearch/SuperPointPretrainedNetwork
+
+Adapted by Philipp Lindenberger (Phil26AT)
+"""
+import os.path
+
+import torch
+from torch import nn
+
+# fix: import from the vendored `gim` package -- this diff adds the files
+# under gim/..., so the upstream `networks.` prefix would not resolve here
+from gim.lightglue.models.base_model import BaseModel
+from gim.lightglue.models.utils.misc import pad_and_stack
+
+
+def simple_nms(scores, radius):
+    """Perform non maximum suppression on the heatmap using max-pooling.
+    This method does not suppress contiguous points that have the same score.
+    Args:
+        scores: the score heatmap of size `(B, H, W)`.
+        radius: an integer scalar, the radius of the NMS window.
+    """
+
+    def max_pool(x):
+        return torch.nn.functional.max_pool2d(
+            x, kernel_size=radius * 2 + 1, stride=1, padding=radius
+        )
+
+    zeros = torch.zeros_like(scores)
+    max_mask = scores == max_pool(scores)
+    for _ in range(2):
+        supp_mask = max_pool(max_mask.float()) > 0
+        supp_scores = torch.where(supp_mask, zeros, scores)
+        new_max_mask = supp_scores == max_pool(supp_scores)
+        max_mask = max_mask | (new_max_mask & (~supp_mask))
+    return torch.where(max_mask, scores, zeros)
+
+
+def top_k_keypoints(keypoints, scores, k):
+    if k >= len(keypoints):
+        return keypoints, scores
+    scores, indices = torch.topk(scores, k, dim=0, sorted=True)
+    return keypoints[indices], scores
+
+
+def sample_k_keypoints(keypoints, scores, k):
+    if k >= len(keypoints):
+        return keypoints, scores
+    indices = torch.multinomial(scores, k, replacement=False)
+    return keypoints[indices], scores[indices]
+
+
+def soft_argmax_refinement(keypoints, scores, radius: int):
+    width = 2 * radius + 1
+    sum_ = torch.nn.functional.avg_pool2d(
+        scores[:, None], width, 1, radius, divisor_override=1
+    )
+    ar = torch.arange(-radius, radius + 1).to(scores)
+    kernel_x = ar[None].expand(width, -1)[None, None]
+    dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius)
+    dy = torch.nn.functional.conv2d(
+        scores[:, None], kernel_x.transpose(2, 3), padding=radius
+    )
+    dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None]
+    refined_keypoints = []
+    for i, kpts in enumerate(keypoints):
+        delta = dydx[i][tuple(kpts.t())]
+        refined_keypoints.append(kpts.float() + delta)
+    return refined_keypoints
+
+
+# Legacy (broken) sampling of the descriptors
+def sample_descriptors(keypoints, descriptors, s):
+    b, c, h, w = descriptors.shape
+    keypoints = keypoints - s / 2 + 0.5
+    keypoints /= torch.tensor(
+        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+    ).to(
+        keypoints
+    )[None]
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
+    descriptors = torch.nn.functional.grid_sample(
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+    )
+    descriptors = torch.nn.functional.normalize(
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
+    return descriptors
+
+
+# The
original keypoint sampling is incorrect. We patch it here but +# keep the original one above for legacy. +def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8): + """Interpolate descriptors at keypoint locations""" + b, c, h, w = descriptors.shape + keypoints = keypoints / (keypoints.new_tensor([w, h]) * s) + keypoints = keypoints * 2 - 1 # normalize to (-1, 1) + descriptors = torch.nn.functional.grid_sample( + descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False + ) + descriptors = torch.nn.functional.normalize( + descriptors.reshape(b, c, -1), p=2, dim=1 + ) + return descriptors + + +class SuperPoint(BaseModel): + default_conf = { + "has_detector": True, + "has_descriptor": True, + "descriptor_dim": 256, + # Inference + "sparse_outputs": True, + "dense_outputs": False, + "nms_radius": 4, + "refinement_radius": 0, + "detection_threshold": 0.005, + "max_num_keypoints": -1, + "max_num_keypoints_val": None, + "force_num_keypoints": False, + "randomize_keypoints_training": False, + "remove_borders": 4, + "legacy_sampling": True, # True to use the old broken sampling + } + required_data_keys = ["image"] + + # checkpoint_url = "https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superpoint_v1.pth" # noqa: E501 + + def _init(self, conf): + self.relu = nn.ReLU(inplace=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 + + self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) + self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) + self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) + self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) + self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) + self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) + self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) + self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) + + if conf.has_detector: + self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) + for param in self.convPa.parameters(): + param.requires_grad = False + for param in self.convPb.parameters(): + param.requires_grad = False + + if conf.has_descriptor: + self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) + self.convDb = nn.Conv2d( + c5, conf.descriptor_dim, kernel_size=1, stride=1, padding=0 + ) + + # self.load_state_dict(torch.load(os.path.join('weights', 'superpoint_v1.pth'))) + + def _forward(self, data): + image = data["image"] + if image.shape[1] == 3: # RGB + scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1) + image = (image * scale).sum(1, keepdim=True) + + # Shared Encoder + x = self.relu(self.conv1a(image)) + x = self.relu(self.conv1b(x)) + x = self.pool(x) + x = self.relu(self.conv2a(x)) + x = self.relu(self.conv2b(x)) + x = self.pool(x) + x = self.relu(self.conv3a(x)) + x = self.relu(self.conv3b(x)) + x = self.pool(x) + x = self.relu(self.conv4a(x)) + x = self.relu(self.conv4b(x)) + + pred = {} + if self.conf.has_detector: + # Compute the dense keypoint scores + cPa = self.relu(self.convPa(x)) + scores = self.convPb(cPa) + scores = torch.nn.functional.softmax(scores, 1)[:, :-1] + b, c, h, w = scores.shape + scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) + scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8) + pred["keypoint_scores"] = dense_scores = scores + if 
self.conf.has_descriptor: + # Compute the dense descriptors + cDa = self.relu(self.convDa(x)) + dense_desc = self.convDb(cDa) + dense_desc = torch.nn.functional.normalize(dense_desc, p=2, dim=1) + pred["descriptors"] = dense_desc + + if self.conf.sparse_outputs: + assert self.conf.has_detector and self.conf.has_descriptor + + scores = simple_nms(scores, self.conf.nms_radius) + + # Discard keypoints near the image borders + if self.conf.remove_borders: + scores[:, : self.conf.remove_borders] = -1 + scores[:, :, : self.conf.remove_borders] = -1 + if "image_size" in data: + for i in range(scores.shape[0]): + w, h = data["image_size"][i] + scores[i, int(h.item()) - self.conf.remove_borders :] = -1 + scores[i, :, int(w.item()) - self.conf.remove_borders :] = -1 + else: + scores[:, -self.conf.remove_borders :] = -1 + scores[:, :, -self.conf.remove_borders :] = -1 + + # Extract keypoints + best_kp = torch.where(scores > self.conf.detection_threshold) + scores = scores[best_kp] + + # Separate into batches + keypoints = [ + torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b) + ] + scores = [scores[best_kp[0] == i] for i in range(b)] + + # Keep the k keypoints with highest score + max_kps = self.conf.max_num_keypoints + + # for val we allow different + if not self.training and self.conf.max_num_keypoints_val is not None: + max_kps = self.conf.max_num_keypoints_val + + # Keep the k keypoints with highest score + if max_kps > 0: + if self.conf.randomize_keypoints_training and self.training: + # instead of selecting top-k, sample k by score weights + keypoints, scores = list( + zip( + *[ + sample_k_keypoints(k, s, max_kps) + for k, s in zip(keypoints, scores) + ] + ) + ) + else: + keypoints, scores = list( + zip( + *[ + top_k_keypoints(k, s, max_kps) + for k, s in zip(keypoints, scores) + ] + ) + ) + keypoints, scores = list(keypoints), list(scores) + + if self.conf["refinement_radius"] > 0: + keypoints = soft_argmax_refinement( + keypoints, dense_scores, self.conf["refinement_radius"] + ) + + # Convert (h, w) to (x, y) + keypoints = [torch.flip(k, [1]).float() for k in keypoints] + + if self.conf.force_num_keypoints: + keypoints = pad_and_stack( + keypoints, + max_kps, + -2, + mode="random_c", + bounds=( + 0, + data.get("image_size", torch.tensor(image.shape[-2:])) + .min() + .item(), + ), + ) + scores = pad_and_stack(scores, max_kps, -1, mode="zeros") + else: + keypoints = torch.stack(keypoints, 0) + scores = torch.stack(scores, 0) + + # Extract descriptors + if (len(keypoints) == 1) or self.conf.force_num_keypoints: + # Batch sampling of the descriptors + if self.conf.legacy_sampling: + desc = sample_descriptors(keypoints, dense_desc, 8) + else: + desc = sample_descriptors_fix_sampling(keypoints, dense_desc, 8) + else: + if self.conf.legacy_sampling: + desc = [ + sample_descriptors(k[None], d[None], 8)[0] + for k, d in zip(keypoints, dense_desc) + ] + else: + desc = [ + sample_descriptors_fix_sampling(k[None], d[None], 8)[0] + for k, d in zip(keypoints, dense_desc) + ] + + pred = { + "keypoints": keypoints + 0.5, + "descriptors": desc.transpose(-1, -2), + } + + if self.conf.dense_outputs: + pred["dense_descriptors"] = dense_desc + + return pred + + def loss(self, pred, data): + raise NotImplementedError + + def metrics(self, pred, data): + raise NotImplementedError diff --git a/third_party/gim/gim/lightglue/utils/__init__.py b/third_party/gim/gim/lightglue/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 
--- /dev/null +++ b/third_party/gim/gim/lightglue/utils/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/third_party/gim/gim/lightglue/utils/tools.py b/third_party/gim/gim/lightglue/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..6a27f4a491e1675557b992401208bbe4c355edd2 --- /dev/null +++ b/third_party/gim/gim/lightglue/utils/tools.py @@ -0,0 +1,269 @@ +""" +Various handy Python and PyTorch utils. + +Author: Paul-Edouard Sarlin (skydes) +""" + +import os +import random +import time +from collections.abc import Iterable +from contextlib import contextmanager + +import numpy as np +import torch + + +class AverageMetric: + def __init__(self): + self._sum = 0 + self._num_examples = 0 + + def update(self, tensor): + assert tensor.dim() == 1 + tensor = tensor[~torch.isnan(tensor)] + self._sum += tensor.sum().item() + self._num_examples += len(tensor) + + def compute(self): + if self._num_examples == 0: + return np.nan + else: + return self._sum / self._num_examples + + +# same as AverageMetric, but tracks all elements +class FAverageMetric: + def __init__(self): + self._sum = 0 + self._num_examples = 0 + self._elements = [] + + def update(self, tensor): + self._elements += tensor.cpu().numpy().tolist() + assert tensor.dim() == 1 + tensor = tensor[~torch.isnan(tensor)] + self._sum += tensor.sum().item() + self._num_examples += len(tensor) + + def compute(self): + if self._num_examples == 0: + return np.nan + else: + return self._sum / self._num_examples + + +class MedianMetric: + def __init__(self): + self._elements = [] + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + else: + return np.nanmedian(self._elements) + + +class PRMetric: + def __init__(self): + self.labels = [] + self.predictions = [] + + @torch.no_grad() + def update(self, labels, predictions, mask=None): + assert labels.shape == predictions.shape + self.labels += ( + (labels[mask] if mask is not None else labels).cpu().numpy().tolist() + ) + self.predictions += ( + (predictions[mask] if mask is not None else predictions) + .cpu() + .numpy() + .tolist() + ) + + @torch.no_grad() + def compute(self): + return np.array(self.labels), np.array(self.predictions) + + def reset(self): + self.labels = [] + self.predictions = [] + + +class QuantileMetric: + def __init__(self, q=0.05): + self._elements = [] + self.q = q + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + else: + return np.nanquantile(self._elements, self.q) + + +class RecallMetric: + def __init__(self, ths, elements=[]): + self._elements = elements + self.ths = ths + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if isinstance(self.ths, Iterable): + return [self.compute_(th) for th in self.ths] + else: + return self.compute_(self.ths[0]) + + def compute_(self, th): + if len(self._elements) == 0: + return np.nan + else: + s = (np.array(self._elements) < th).sum() + return s / len(self._elements) + + +def cal_error_auc(errors, thresholds): + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = (np.arange(len(errors)) + 1) / len(errors) + errors = np.r_[0.0, errors] + recall = np.r_[0.0, recall] + aucs = [] + for t in thresholds: + last_index = 
np.searchsorted(errors, t) + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + aucs.append(np.round((np.trapz(r, x=e) / t), 4)) + return aucs + + +class AUCMetric: + def __init__(self, thresholds, elements=None): + self._elements = elements + self.thresholds = thresholds + if not isinstance(thresholds, list): + self.thresholds = [thresholds] + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + else: + return cal_error_auc(self._elements, self.thresholds) + + +class Timer(object): + """A simpler timer context object. + Usage: + ``` + > with Timer('mytimer'): + > # some computations + [mytimer] Elapsed: X + ``` + """ + + def __init__(self, name=None): + self.name = name + + def __enter__(self): + self.tstart = time.time() + return self + + def __exit__(self, type, value, traceback): + self.duration = time.time() - self.tstart + if self.name is not None: + print("[%s] Elapsed: %s" % (self.name, self.duration)) + + +def get_class(mod_path, BaseClass): + """Get the class object which inherits from BaseClass and is defined in + the module named mod_name, child of base_path. + """ + import inspect + + mod = __import__(mod_path, fromlist=[""]) + classes = inspect.getmembers(mod, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == mod_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseClass)] + assert len(classes) == 1, classes + return classes[0][1] + + +def set_num_threads(nt): + """Force numpy and other libraries to use a limited number of threads.""" + try: + import mkl + except ImportError: + pass + else: + mkl.set_num_threads(nt) + torch.set_num_threads(1) + os.environ["IPC_ENABLE"] = "1" + for o in [ + "OPENBLAS_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + ]: + os.environ[o] = str(nt) + + +def set_seed(seed): + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_random_state(with_cuda): + pth_state = torch.get_rng_state() + np_state = np.random.get_state() + py_state = random.getstate() + if torch.cuda.is_available() and with_cuda: + cuda_state = torch.cuda.get_rng_state_all() + else: + cuda_state = None + return pth_state, np_state, py_state, cuda_state + + +def set_random_state(state): + pth_state, np_state, py_state, cuda_state = state + torch.set_rng_state(pth_state) + np.random.set_state(np_state) + random.setstate(py_state) + if ( + cuda_state is not None + and torch.cuda.is_available() + and len(cuda_state) == torch.cuda.device_count() + ): + torch.cuda.set_rng_state_all(cuda_state) + + +@contextmanager +def fork_rng(seed=None, with_cuda=True): + state = get_random_state(with_cuda) + if seed is not None: + set_seed(seed) + try: + yield + finally: + set_random_state(state) diff --git a/third_party/gim/gim/loftr/__init__.py b/third_party/gim/gim/loftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/third_party/gim/gim/loftr/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/third_party/gim/gim/loftr/backbone/__init__.py b/third_party/gim/gim/loftr/backbone/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..1040aba694eeda5828ac7232e52a87ead0179a94 --- /dev/null +++ b/third_party/gim/gim/loftr/backbone/__init__.py @@ -0,0 +1,11 @@ +from .resnet import ResNetFPN_8_2 + + +def build_backbone(config): + if config['backbone_type'] == 'ResNetFPN': + if config['resolution'] == (8, 2): + return ResNetFPN_8_2(config['resnetfpn']) + elif config['resolution'] == (16, 4): + return ResNetFPN_16_4(config['resnetfpn']) + else: + raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/third_party/gim/gim/loftr/backbone/resnet.py b/third_party/gim/gim/loftr/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..526e38f10853ee6255f342b0faf57b67ab30a3f4 --- /dev/null +++ b/third_party/gim/gim/loftr/backbone/resnet.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from typing import Type, Callable, Union, List, Optional + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
+ + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + # self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + # dilate=replace_stride_with_dilation[2]) + # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + # + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + # nn.init.constant_(m.weight, 1) + # nn.init.constant_(m.bias, 0) + # + # # Zero-initialize the last BN in each residual branch, + # # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + # if zero_init_residual: + # for m in self.modules(): + # if isinstance(m, Bottleneck): + # nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + # elif isinstance(m, BasicBlock): + # nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)] + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + # x = self.conv1(x) # (2, 64, 320, 320) + # x = self.bn1(x) # (2, 64, 320, 320) + # x1 = self.relu(x) # (2, 64, 320, 320) + # x2 = self.maxpool(x1) # (2, 64, 160, 160) + + # x2 = self.layer1(x1) # (2, 64, 160, 160) + # x3 = self.layer2(x2) # (2, 128, 80, 80) + # x4 = self.layer3(x3) # (2, 256, 40, 40) + # x = self.layer4(x) # (2, 512, 20, 20) + + # x = self.avgpool(x) # (2, 512, 1, 1) + # x = torch.flatten(x, 1) # (2, 512) + # x = self.fc(x) # (2, 1000) + + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + return x1, x2, x3 + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('layer4.'): state_dict.pop(k) + if k.startswith('fc.'): state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + # initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + # self.in_planes = initial_dim + + # Networks + # self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + # self.bn1 = nn.BatchNorm2d(initial_dim) + # self.relu = nn.ReLU(inplace=True) + + # self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + # self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + # self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + self.encode = ResNet(Bottleneck, [3, 4, 6, 3]) # resnet50 + + # 3. 
FPN upsample + self.layer3_outconv = conv1x1(block_dims[5], block_dims[3]) + self.layer2_outconv = conv1x1(block_dims[4], block_dims[3]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + self.layer1_outconv = conv1x1(block_dims[3], block_dims[2]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + # x0 = self.relu(self.bn1(self.conv1(x))) + # x1 = self.layer1(x0) # 1/2 + # x2 = self.layer2(x1) # 1/4 + # x3 = self.layer3(x2) # 1/8 + + # x1: (2, 64, 320, 320) + # x2: (2, 128, 160, 160) + # x3: (2, 256, 80, 80) + x1, x2, x3 = self.encode(x) + + # FPN + x3_out = self.layer3_outconv(x3) # (2, 256, 80, 80) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) # (2, 256, 160, 160) + x2_out = self.layer2_outconv(x2) # (2, 256, 160, 160) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) # (2, 196, 160, 160) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) # (2, 196, 320, 320) + x1_out = self.layer1_outconv(x1) # (2, 196, 320, 320) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +if __name__ == '__main__': + # Original form + # config = dict(initial_dim=128, block_dims=[128, 196, 256]) + # model = ResNetFPN_8_2(config) + # # output (list): + # # 0: (2, 256, 80, 80) + # # 1: (2, 128, 320, 320) + # output = model(torch.randn(2, 1, 640, 640)) + + # model = ResNet(BasicBlock, [2, 2, 2, 2]) + # # weights = torch.load('resnet18(5c106cde).ckpt', map_location='cpu') + # # model.load_state_dict(weights) + # output = model(torch.randn(2, 3, 640, 640)) + + config = dict(initial_dim=128, block_dims=[64, 128, 196, 256]) + model = ResNetFPN_8_2(config) + # output (list): + # 0: (2, 256, 80, 80) + # 1: (2, 128, 320, 320) + output = model(torch.randn(2, 3, 640, 640)) diff --git a/third_party/gim/gim/loftr/backbone/resnet_fpn.py b/third_party/gim/gim/loftr/backbone/resnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..18e4caf34f065aa46e05913fdccb9a93403148fc --- /dev/null +++ b/third_party/gim/gim/loftr/backbone/resnet_fpn.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = 
nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +class ResNetFPN_16_4(nn.Module): + """ + ResNet+FPN, output resolution are 1/16 and 1/4. + Each block has 2 layers. 
+ """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(3, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. FPN upsample + self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out = self.layer4_outconv(x4) + + x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + return [x4_out, x2_out] diff --git a/third_party/gim/gim/loftr/config.py b/third_party/gim/gim/loftr/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ee03c9b6fea8430318b9be64fa15bd2268a04704 --- /dev/null +++ b/third_party/gim/gim/loftr/config.py @@ -0,0 +1,77 @@ +from yacs.config import CfgNode as CN + +_CN = CN() +_CN.TEMP_BUG_FIX = True + +############## ↓ LoFTR Pipeline ↓ ############## +_CN.LOFTR = CN() +_CN.LOFTR.WEIGHT = None + +############## ↓ LoFTR Pipeline ↓ ############## +_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN' +_CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.LOFTR.FINE_CONCAT_COARSE_FEAT = False + +# 1. LoFTR-backbone (local feature CNN) config +_CN.LOFTR.RESNETFPN = CN() +_CN.LOFTR.RESNETFPN.INITIAL_DIM = 128 +_CN.LOFTR.RESNETFPN.BLOCK_DIMS = [64, 128, 196, 256, 512, 1024] # s1, s2, s3 + +# 2. LoFTR-coarse module config +_CN.LOFTR.COARSE = CN() +_CN.LOFTR.COARSE.D_MODEL = 256 +_CN.LOFTR.COARSE.NHEAD = 8 +_CN.LOFTR.COARSE.LAYER_NAMES = 4 +_CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] + +# 3. 
Coarse-Matching config +_CN.LOFTR.MATCH_COARSE = CN() +_CN.LOFTR.MATCH_COARSE.THR = 0.2 +_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2 +_CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax', 'sinkhorn'] +_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3 +_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False +_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory +_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock +_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = False + +# 4. LoFTR-fine module config +_CN.LOFTR.FINE = CN() +_CN.LOFTR.FINE.D_MODEL = 128 +_CN.LOFTR.FINE.NHEAD = 8 +_CN.LOFTR.FINE.LAYER_NAMES = 1 +_CN.LOFTR.FINE.ATTENTION = 'linear' + +# 5. LoFTR Losses +# -- # coarse-level +_CN.LOFTR.LOSS = CN() +_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy'] +_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0 +# _CN.LOFTR.LOSS.SPARSE_SPVS = False +# -- - -- # focal loss (coarse) +_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25 +_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0 +_CN.LOFTR.LOSS.POS_WEIGHT = 1.0 +_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0 +# _CN.LOFTR.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not. +# use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE` + +# -- # fine-level +_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] +_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0 +_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) + +# Overlap +_CN.LOFTR.LOSS.OVERLAP_WEIGHT = 20.0 +_CN.LOFTR.LOSS.OVERLAP_FOCAL_ALPHA = 0.25 +_CN.LOFTR.LOSS.OVERLAP_FOCAL_GAMMA = 5.0 + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _CN.clone() diff --git a/third_party/gim/gim/loftr/configs/__init__.py b/third_party/gim/gim/loftr/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/third_party/gim/gim/loftr/configs/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/third_party/gim/gim/loftr/configs/outdoor/__init__.py b/third_party/gim/gim/loftr/configs/outdoor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..846998ecc3b1961b957a1a98afa9f1a899079ee4 --- /dev/null +++ b/third_party/gim/gim/loftr/configs/outdoor/__init__.py @@ -0,0 +1,12 @@ +from networks.loftr.config import get_cfg_defaults as get_network_cfg +from trainer.config import get_cfg_defaults as get_trainer_cfg + +# network +network_cfg = get_network_cfg() +network_cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 + +# optimizer +trainer_cfg = get_trainer_cfg() +trainer_cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs +trainer_cfg.TRAINER.WARMUP_RATIO = 0.1 +trainer_cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24]
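For orientation, a minimal sketch of how this config tree is consumed (assuming `third_party/gim` is on `sys.path` so the vendored package imports as `gim.loftr`; `lower_config` is defined in misc.py further down in this diff):

from gim.loftr.config import get_cfg_defaults
from gim.loftr.misc import lower_config
from gim.loftr.loftr import LoFTR

cfg = get_cfg_defaults()              # yacs CfgNode with the defaults above
config = lower_config(cfg)["loftr"]   # recursively lower-cased plain dict
matcher = LoFTR(config).eval()        # LoFTR (added below) takes the plain dict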
diff --git a/third_party/gim/gim/loftr/loftr.py b/third_party/gim/gim/loftr/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad5a1aeef225724b0bd35befb92eb85417a5106 --- /dev/null +++ b/third_party/gim/gim/loftr/loftr.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .utils.position_encoding import PositionEncodingSine +from .submodules import LocalFeatureTransformer, FinePreprocess +import warnings +from .utils.coarse_matching import CoarseMatching +warnings.simplefilter("ignore", UserWarning) +from .utils.fine_matching import FineMatching + + +class LoFTR(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + self.pos_encoding = PositionEncodingSine( + config['coarse']['d_model'], + temp_bug_fix=False) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + + """ + outdoor_ds.ckpt: {OrderedDict: 211} + backbone: {OrderedDict: 107} + loftr_coarse: {OrderedDict: 80} + loftr_fine: {OrderedDict: 20} + fine_preprocess: {OrderedDict: 4} + """ + # if config['weight'] is not None: + # weights = torch.load(config['weight'], map_location='cpu')['state_dict'] + # self.load_state_dict(weights) + # print(config['weight'] + ' load success.') + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_f = self.backbone(torch.cat([data['color0'], data['color1']], dim=0)) # h == h0 == h1, w == w0 == w1; feats_c: (bs*2, 256, h//8, w//8), feats_f: (bs*2, 128, h//2, w//2) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) # feat_c0, feat_c1: (bs, 256, h//8, w//8), feat_f0, feat_f1: (bs, 128, h//2, w//2) + else: # handle different input shapes + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['color0']), self.backbone(data['color1']) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # 2. coarse-level loftr module + b, c, h0, w0 = feat_c0.size() + _, _, h1, w1 = feat_c1.size() + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') + feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + # 3. match coarse-level + self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) + if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted + feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) + + # 5. match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
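A sketch of the in-place contract of forward(), under assumed shapes (using the hypothetical `matcher` from the sketch above; this fork feeds the backbone from 'color0'/'color1', so those tensors must match the backbone's input channels, 1 for the ResNetFPN_8_2 above):

import torch

data = {
    "image0": torch.rand(1, 1, 480, 640),   # used for bookkeeping ('bs', 'hw0_i', ...)
    "image1": torch.rand(1, 1, 480, 640),
    "color0": torch.rand(1, 1, 480, 640),   # assumed single-channel, per the backbone
    "color1": torch.rand(1, 1, 480, 640),
}
with torch.no_grad():
    matcher(data)                            # writes results back into the dict
mkpts0, mkpts1, mconf = data["mkpts0_f"], data["mkpts1_f"], data["mconf"]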
+ + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('model.'): + state_dict[k.replace('model.', '', 1)] = state_dict.pop(k) + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/third_party/gim/gim/loftr/misc.py b/third_party/gim/gim/loftr/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..61cd57bf1e4e5aacab58e42e9277a4ad12990dc9 --- /dev/null +++ b/third_party/gim/gim/loftr/misc.py @@ -0,0 +1,100 @@ +import os +import contextlib +import joblib +from typing import Union +from loguru import _Logger, logger +from itertools import chain + +import torch +from yacs.config import CfgNode as CN +from pytorch_lightning.utilities import rank_zero_only + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +def upper_config(dict_cfg): + if not isinstance(dict_cfg, dict): + return dict_cfg + return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} + + +def log_on(condition, message, level): + if condition: + assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] + logger.log(level, message) + + +def get_rank_zero_only_logger(logger: _Logger): + if rank_zero_only.rank == 0: + return logger + else: + for _level in logger._core.levels.keys(): + level = _level.lower() + setattr(logger, level, + lambda x: None) + logger._log = lambda x: None + return logger + + +def setup_gpus(gpus: Union[str, int]) -> int: + """ A temporary fix for pytorch-lightning 1.3.x """ + gpus = str(gpus) + gpu_ids = [] + + if ',' not in gpus: + n_gpus = int(gpus) + return n_gpus if n_gpus != -1 else torch.cuda.device_count() + else: + gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] + + # setup environment variables + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + if visible_devices is None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) + visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') + logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') + else: + logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') + return len(gpu_ids) + + +def flattenList(x): + return list(chain(*x)) + + +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument + + Usage: + with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: + Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) + + When iterating over a generator, direct use of tqdm is also a solution (but monitor the task queuing, instead of finishing) + ret_vals = Parallel(n_jobs=args.world_size)( + delayed(lambda x: _compute_cov_score(pid, *x))(param) + for param in tqdm(combinations(image_ids, 2), + desc=f'Computing cov_score of [{pid}]', + total=len(image_ids)*(len(image_ids)-1)/2)) + Src: https://stackoverflow.com/a/58936697 + """ + class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, 
**kwargs) + + old_batch_callback = joblib.parallel.BatchCompletionCallBack + joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + joblib.parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() diff --git a/third_party/gim/gim/loftr/submodules/__init__.py b/third_party/gim/gim/loftr/submodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a --- /dev/null +++ b/third_party/gim/gim/loftr/submodules/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer +from .fine_preprocess import FinePreprocess diff --git a/third_party/gim/gim/loftr/submodules/attentions.py b/third_party/gim/gim/loftr/submodules/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..b73c5a6a6a722a44c0b68f70cb77c0988b8a5fb3 --- /dev/null +++ b/third_party/gim/gim/loftr/submodules/attentions.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() diff --git a/third_party/gim/gim/loftr/submodules/fine_preprocess.py b/third_party/gim/gim/loftr/submodules/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567 --- /dev/null +++ b/third_party/gim/gim/loftr/submodules/fine_preprocess.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.cat_c_feat = config['fine_concat_coarse_feat'] + self.W = self.config['fine_window_size'] + + d_model_c = self.config['coarse']['d_model'] + d_model_f = self.config['fine']['d_model'] + self.d_model_f = d_model_f + if self.cat_c_feat: + self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) + self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + return feat0, feat1 + + # 1. unfold(crop) all local windows + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + + # 2. 
select only the predicted matches + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + # option: use coarse-level loftr feature as context: concat and linear + if self.cat_c_feat: + feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], + feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] + feat_cf_win = self.merge_feat(torch.cat([ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] + ], -1)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) + + return feat_f0_unfold, feat_f1_unfold diff --git a/third_party/gim/gim/loftr/submodules/transformer.py b/third_party/gim/gim/loftr/submodules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e70cafddc912901a04d2491bf6f9e9dbaaf4e793 --- /dev/null +++ b/third_party/gim/gim/loftr/submodules/transformer.py @@ -0,0 +1,101 @@ +import copy +import torch +import torch.nn as nn +from .attentions import LinearAttention, FullAttention + + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = ['self', 'cross'] * config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + 
Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 diff --git a/third_party/gim/gim/loftr/utils/__init__.py b/third_party/gim/gim/loftr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13b83c33239c952cea7ed746cf043a26c46e5109 --- /dev/null +++ b/third_party/gim/gim/loftr/utils/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +# @Author : xuelun diff --git a/third_party/gim/gim/loftr/utils/coarse_matching.py b/third_party/gim/gim/loftr/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..8f225ed3dcb6becd229302ece53d8cc8b43e42f0 --- /dev/null +++ b/third_party/gim/gim/loftr/utils/coarse_matching.py @@ -0,0 +1,259 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +INF = 1e9 + + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + # -- # for training fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + # we provide 2 options for differentiable matching + self.match_type = config['match_type'] + if self.match_type == 'dual_softmax': + self.temperature = config['dsmax_temperature'] + elif self.match_type == 'sinkhorn': + try: + from .superglue import log_optimal_transport + except ImportError: + raise ImportError("download superglue.py first!") + self.log_optimal_transport = log_optimal_transport + self.bin_score = nn.Parameter( + torch.tensor(config['skh_init_bin_score'], requires_grad=True)) + self.skh_iters = config['skh_iters'] + self.skh_prefilter = config['skh_prefilter'] + else: + raise
NotImplementedError() + + def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): + """ + Args: + feat_c0 (torch.Tensor): [N, L, C] + feat_c1 (torch.Tensor): [N, S, C] + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + # noinspection PyArgumentList + N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) + + # normalize + feat_c0, feat_c1 = map(lambda feat: feat/feat.shape[-1]**.5, [feat_c0, feat_c1]) + + conf_matrix = None + if self.match_type == 'dual_softmax': + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)/self.temperature + if mask_c0 is not None: + sim_matrix.masked_fill_(~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF) + conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) + + elif self.match_type == 'sinkhorn': + # sinkhorn, dustbin included + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) + if mask_c0 is not None: + sim_matrix[:, :L, :S].masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + + # build uniform prior & use sinkhorn + log_assign_matrix = self.log_optimal_transport( + sim_matrix, self.bin_score, self.skh_iters) + assign_matrix = log_assign_matrix.exp() + conf_matrix = assign_matrix[:, :-1, :-1] + + # filter prediction with dustbin score (only in evaluation mode) + if not self.training and self.skh_prefilter: + filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L] + filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S] + conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 + conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 + + if self.config['sparse_spvs']: + data.update({'conf_matrix_with_bin': assign_matrix.clone()}) + + data.update({'conf_matrix': conf_matrix}) + + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(conf_matrix, data)) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. 
find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. 
Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + # These matches are the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches
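The thresholding and mutual-nearest filtering above, distilled into a standalone toy sketch (random similarities, not the module's API):

import torch

sim = torch.randn(1, 6, 6)                                # [N, L, S] coarse similarity
conf = sim.softmax(dim=1) * sim.softmax(dim=2)            # dual-softmax confidence
mask = conf > 0.2                                         # confidence threshold (thr)
mask = mask \
    & (conf == conf.max(dim=2, keepdim=True).values) \
    & (conf == conf.max(dim=1, keepdim=True).values)      # mutual nearest neighbours
b_ids, i_ids, j_ids = torch.nonzero(mask, as_tuple=True)  # surviving coarse matches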
diff --git a/third_party/gim/gim/loftr/utils/fine_matching.py b/third_party/gim/gim/loftr/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a0bf6096963df69e088ed826ea334d3114c67c --- /dev/null +++ b/third_party/gim/gim/loftr/utils/fine_matching.py @@ -0,0 +1,74 @@ +import math +import torch +import torch.nn as nn + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self): + super().__init__() + + def forward(self, feat_f0, feat_f1, data): + """ + Args: + feat_f0 (torch.Tensor): [M, WW, C] + feat_f1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_f0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training is False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_f0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + feat_f0_picked = feat_f0[:, WW//2, :] + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. / C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over <x, y> + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + + # compute absolute kpt coords + self.get_fine_match(coords_normalized, data) + + @torch.no_grad() + def get_fine_match(self, coords_normed, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # mkpts0_f and mkpts1_f + mkpts0_f = data['mkpts0_c'] + scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale + mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + })
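For intuition, the soft-argmax step above in isolation, on toy heatmaps (same kornia helpers as imported above):

import torch
from kornia.geometry.subpix import dsnt
from kornia.utils.grid import create_meshgrid

W = 5
heatmap = torch.softmax(torch.randn(3, W * W), dim=1).view(3, W, W)    # [M, W, W]
coords = dsnt.spatial_expectation2d(heatmap[None], True)[0]            # [M, 2] in [-1, 1]
grid = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)   # [1, WW, 2]
var = (grid ** 2 * heatmap.view(3, -1, 1)).sum(dim=1) - coords ** 2    # E[x^2] - E[x]^2
std = var.clamp(min=1e-10).sqrt().sum(-1)                              # per-match spread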
+ """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + + if temp_bug_fix: + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + else: # a buggy implementation (for backward compatability only) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] diff --git a/third_party/gim/gim/mit_semseg/__init__.py b/third_party/gim/gim/mit_semseg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ccf2d50ca12d8706b47e69d172f66ccd19f3dd --- /dev/null +++ b/third_party/gim/gim/mit_semseg/__init__.py @@ -0,0 +1,5 @@ +""" +MIT CSAIL Semantic Segmentation +""" + +__version__ = '1.0.0' diff --git a/third_party/gim/gim/mit_semseg/config/__init__.py b/third_party/gim/gim/mit_semseg/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7cfcbb8ae15ef50207c000d4c838a5b68b9c43 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/config/__init__.py @@ -0,0 +1 @@ +from .defaults import _C as cfg diff --git a/third_party/gim/gim/mit_semseg/config/defaults.py b/third_party/gim/gim/mit_semseg/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..83818ce04fce587eae76fef00e181b63541add6f --- /dev/null +++ b/third_party/gim/gim/mit_semseg/config/defaults.py @@ -0,0 +1,97 @@ +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() +_C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup" + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASET = CN() +_C.DATASET.root_dataset = "./data/" +_C.DATASET.list_train = "./data/training.odgt" +_C.DATASET.list_val = "./data/validation.odgt" +_C.DATASET.num_class = 150 +# multiscale train/test, size of short edge (int or tuple) +_C.DATASET.imgSizes = (300, 375, 450, 525, 600) +# maximum input image size of long edge +_C.DATASET.imgMaxSize = 1000 +# maxmimum downsampling rate of the network +_C.DATASET.padding_constant = 8 +# downsampling rate of the segmentation label +_C.DATASET.segm_downsampling_rate = 8 +# randomly horizontally flip images when train/test +_C.DATASET.random_flip = True + +# ----------------------------------------------------------------------------- +# Model +# ----------------------------------------------------------------------------- +_C.MODEL = CN() +# architecture of net_encoder +_C.MODEL.arch_encoder = "resnet50dilated" +# architecture of net_decoder +_C.MODEL.arch_decoder = "ppm_deepsup" +# weights to finetune net_encoder +_C.MODEL.weights_encoder = "" +# weights to finetune net_decoder +_C.MODEL.weights_decoder = "" +# number of feature channels between encoder and 
diff --git a/third_party/gim/gim/mit_semseg/__init__.py b/third_party/gim/gim/mit_semseg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ccf2d50ca12d8706b47e69d172f66ccd19f3dd --- /dev/null +++ b/third_party/gim/gim/mit_semseg/__init__.py @@ -0,0 +1,5 @@ +""" +MIT CSAIL Semantic Segmentation +""" + +__version__ = '1.0.0' diff --git a/third_party/gim/gim/mit_semseg/config/__init__.py b/third_party/gim/gim/mit_semseg/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7cfcbb8ae15ef50207c000d4c838a5b68b9c43 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/config/__init__.py @@ -0,0 +1 @@ +from .defaults import _C as cfg diff --git a/third_party/gim/gim/mit_semseg/config/defaults.py b/third_party/gim/gim/mit_semseg/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..83818ce04fce587eae76fef00e181b63541add6f --- /dev/null +++ b/third_party/gim/gim/mit_semseg/config/defaults.py @@ -0,0 +1,97 @@ +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() +_C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup" + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASET = CN() +_C.DATASET.root_dataset = "./data/" +_C.DATASET.list_train = "./data/training.odgt" +_C.DATASET.list_val = "./data/validation.odgt" +_C.DATASET.num_class = 150 +# multiscale train/test, size of short edge (int or tuple) +_C.DATASET.imgSizes = (300, 375, 450, 525, 600) +# maximum input image size of long edge +_C.DATASET.imgMaxSize = 1000 +# maximum downsampling rate of the network +_C.DATASET.padding_constant = 8 +# downsampling rate of the segmentation label +_C.DATASET.segm_downsampling_rate = 8 +# randomly horizontally flip images when train/test +_C.DATASET.random_flip = True + +# ----------------------------------------------------------------------------- +# Model +# ----------------------------------------------------------------------------- +_C.MODEL = CN() +# architecture of net_encoder +_C.MODEL.arch_encoder = "resnet50dilated" +# architecture of net_decoder +_C.MODEL.arch_decoder = "ppm_deepsup" +# weights to finetune net_encoder +_C.MODEL.weights_encoder = "" +# weights to finetune net_decoder +_C.MODEL.weights_decoder = "" +# number of feature channels between encoder and decoder +_C.MODEL.fc_dim = 2048 + +# ----------------------------------------------------------------------------- +# Training +# ----------------------------------------------------------------------------- +_C.TRAIN = CN() +_C.TRAIN.batch_size_per_gpu = 2 +# epochs to train for +_C.TRAIN.num_epoch = 20 +# epoch to start training. useful if continuing from a checkpoint +_C.TRAIN.start_epoch = 0 +# iterations of each epoch (irrelevant to batch size) +_C.TRAIN.epoch_iters = 5000 + +_C.TRAIN.optim = "SGD" +_C.TRAIN.lr_encoder = 0.02 +_C.TRAIN.lr_decoder = 0.02 +# power in poly to drop LR +_C.TRAIN.lr_pow = 0.9 +# momentum for sgd, beta1 for adam +_C.TRAIN.beta1 = 0.9 +# weights regularizer +_C.TRAIN.weight_decay = 1e-4 +# the weighting of deep supervision loss +_C.TRAIN.deep_sup_scale = 0.4 +# fix bn params, only under finetuning +_C.TRAIN.fix_bn = False +# number of data loading workers +_C.TRAIN.workers = 16 + +# frequency to display +_C.TRAIN.disp_iter = 20 +# manual seed +_C.TRAIN.seed = 304 + +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- +_C.VAL = CN() +# currently only supports 1 +_C.VAL.batch_size = 1 +# output visualization during validation +_C.VAL.visualize = False +# the checkpoint to evaluate on +_C.VAL.checkpoint = "epoch_20.pth" + +# ----------------------------------------------------------------------------- +# Testing +# ----------------------------------------------------------------------------- +_C.TEST = CN() +# currently only supports 1 +_C.TEST.batch_size = 1 +# the checkpoint to test on +_C.TEST.checkpoint = "epoch_20.pth" +# folder to output visualization results +_C.TEST.result = "./" diff --git a/third_party/gim/gim/mit_semseg/dataset.py b/third_party/gim/gim/mit_semseg/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1657446301613b71c7b213accac67b650766d6ca --- /dev/null +++ b/third_party/gim/gim/mit_semseg/dataset.py @@ -0,0 +1,296 @@ +import os +import json +import torch +from torchvision import transforms +import numpy as np +from PIL import Image + + +def imresize(im, size, interp='bilinear'): + if interp == 'nearest': + resample = Image.NEAREST + elif interp == 'bilinear': + resample = Image.BILINEAR + elif interp == 'bicubic': + resample = Image.BICUBIC + else: + raise Exception('resample method undefined!') + + return im.resize(size, resample) + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, odgt, opt, **kwargs): + # parse options + self.imgSizes = opt.imgSizes + self.imgMaxSize = opt.imgMaxSize + # max down sampling rate of network to avoid rounding during conv or pooling + self.padding_constant = opt.padding_constant + + # parse the input list + self.parse_input_list(odgt, **kwargs) + + # mean and std + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1): + if isinstance(odgt, list): + self.list_sample = odgt + elif isinstance(odgt, str): + self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')] + + if max_sample > 0: + self.list_sample = self.list_sample[0:max_sample] + if start_idx >= 0 and end_idx >= 0: # divide file list + self.list_sample = self.list_sample[start_idx:end_idx] + + self.num_sample = len(self.list_sample) + assert self.num_sample > 0 + print('# samples: {}'.format(self.num_sample)) + + def img_transform(self, img): + # 0-255 to 0-1
+ img = np.float32(np.array(img)) / 255. + img = img.transpose((2, 0, 1)) + img = self.normalize(torch.from_numpy(img.copy())) + return img + + def segm_transform(self, segm): + # to tensor, -1 to 149 + segm = torch.from_numpy(np.array(segm)).long() - 1 + return segm + + # Round x to the nearest multiple of p and x' >= x + def round2nearest_multiple(self, x, p): + return ((x - 1) // p + 1) * p + + +class TrainDataset(BaseDataset): + def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs): + super(TrainDataset, self).__init__(odgt, opt, **kwargs) + self.root_dataset = root_dataset + # down sampling rate of segm label + self.segm_downsampling_rate = opt.segm_downsampling_rate + self.batch_per_gpu = batch_per_gpu + + # classify images into two classes: 1. h > w and 2. h <= w + self.batch_record_list = [[], []] + + # override dataset length when training with batch_per_gpu > 1 + self.cur_idx = 0 + self.if_shuffled = False + + def _get_sub_batch(self): + while True: + # get a sample record + this_sample = self.list_sample[self.cur_idx] + if this_sample['height'] > this_sample['width']: + self.batch_record_list[0].append(this_sample) # h > w, go to 1st class + else: + self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class + + # update current sample pointer + self.cur_idx += 1 + if self.cur_idx >= self.num_sample: + self.cur_idx = 0 + np.random.shuffle(self.list_sample) + + if len(self.batch_record_list[0]) == self.batch_per_gpu: + batch_records = self.batch_record_list[0] + self.batch_record_list[0] = [] + break + elif len(self.batch_record_list[1]) == self.batch_per_gpu: + batch_records = self.batch_record_list[1] + self.batch_record_list[1] = [] + break + return batch_records
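The short-edge/long-edge scaling used by __getitem__ below, together with the rounding rule above, on concrete made-up numbers:

# short edge -> 600, long edge capped at 1000, then pad to a multiple of 8
img_h, img_w = 768, 1024
this_scale = min(600 / min(img_h, img_w), 1000 / max(img_h, img_w))   # = 0.78125
batch_h, batch_w = int(img_h * this_scale), int(img_w * this_scale)   # 600, 800

def round2nearest_multiple(x, p):   # same rule as the method above
    return ((x - 1) // p + 1) * p

assert round2nearest_multiple(batch_h, 8) == 600   # already a multiple of 8
assert round2nearest_multiple(batch_w, 8) == 800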
+ def __getitem__(self, index): + # NOTE: random shuffle for the first time; the shuffle in __init__ is useless + if not self.if_shuffled: + np.random.seed(index) + np.random.shuffle(self.list_sample) + self.if_shuffled = True + + # get sub-batch candidates + batch_records = self._get_sub_batch() + + # resize all images' short edges to the chosen size + if isinstance(self.imgSizes, list) or isinstance(self.imgSizes, tuple): + this_short_size = np.random.choice(self.imgSizes) + else: + this_short_size = self.imgSizes + + # calculate the BATCH's height and width + # since we concat more than one sample, the batch's h and w shall be larger than EACH sample + batch_widths = np.zeros(self.batch_per_gpu, np.int32) + batch_heights = np.zeros(self.batch_per_gpu, np.int32) + for i in range(self.batch_per_gpu): + img_height, img_width = batch_records[i]['height'], batch_records[i]['width'] + this_scale = min( + this_short_size / min(img_height, img_width), \ + self.imgMaxSize / max(img_height, img_width)) + batch_widths[i] = img_width * this_scale + batch_heights[i] = img_height * this_scale + + # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w' + batch_width = np.max(batch_widths) + batch_height = np.max(batch_heights) + batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant)) + batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant)) + + assert self.padding_constant >= self.segm_downsampling_rate, \ + 'padding constant must be equal to or larger than segm downsampling rate' + batch_images = torch.zeros( + self.batch_per_gpu, 3, batch_height, batch_width) + batch_segms = torch.zeros( + self.batch_per_gpu, + batch_height // self.segm_downsampling_rate, + batch_width // self.segm_downsampling_rate).long() + + for i in range(self.batch_per_gpu): + this_record = batch_records[i] + + # load image and label + image_path = os.path.join(self.root_dataset, this_record['fpath_img']) + segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) + + img = Image.open(image_path).convert('RGB') + segm = Image.open(segm_path) + assert(segm.mode == "L") + assert(img.size[0] == segm.size[0]) + assert(img.size[1] == segm.size[1]) + + # random_flip + if np.random.choice([0, 1]): + img = img.transpose(Image.FLIP_LEFT_RIGHT) + segm = segm.transpose(Image.FLIP_LEFT_RIGHT) + + # note that each sample within a mini batch has different scale param + img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear') + segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest') + + # further downsample seg label, need to avoid seg label misalignment + segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate) + segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate) + segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0) + segm_rounded.paste(segm, (0, 0)) + segm = imresize( + segm_rounded, + (segm_rounded.size[0] // self.segm_downsampling_rate, \ + segm_rounded.size[1] // self.segm_downsampling_rate), \ + interp='nearest') + + # image transform, to torch float tensor 3xHxW + img = self.img_transform(img) + + # segm transform, to torch long tensor HxW + segm = self.segm_transform(segm) + + # put into batch arrays + batch_images[i][:, :img.shape[1], :img.shape[2]] = img + batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm + + output = dict() + output['img_data'] = batch_images + output['seg_label'] = batch_segms + return output + + def __len__(self): + return int(1e10) # a fake length due to the trick that every loader maintains its own list + # return self.num_sample
+ + +class ValDataset(BaseDataset): + def __init__(self, root_dataset, odgt, opt, **kwargs): + super(ValDataset, self).__init__(odgt, opt, **kwargs) + self.root_dataset = root_dataset + + def __getitem__(self, index): + this_record = self.list_sample[index] + # load image and label + image_path = os.path.join(self.root_dataset, this_record['fpath_img']) + segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) + img = Image.open(image_path).convert('RGB') + segm = Image.open(segm_path) + assert(segm.mode == "L") + assert(img.size[0] == segm.size[0]) + assert(img.size[1] == segm.size[1]) + + ori_width, ori_height = img.size + + img_resized_list = [] + for this_short_size in self.imgSizes: + # calculate target height and width + scale = min(this_short_size / float(min(ori_height, ori_width)), + self.imgMaxSize / float(max(ori_height, ori_width))) + target_height, target_width = int(ori_height * scale), int(ori_width * scale) + + # to avoid rounding in network + target_width = self.round2nearest_multiple(target_width, self.padding_constant) + target_height = self.round2nearest_multiple(target_height, self.padding_constant) + + # resize images + img_resized = imresize(img, (target_width, target_height), interp='bilinear') + + # image transform, to torch float tensor 3xHxW + img_resized = self.img_transform(img_resized) + img_resized = torch.unsqueeze(img_resized, 0) + img_resized_list.append(img_resized) + + # segm transform, to torch long tensor HxW + segm = self.segm_transform(segm) + batch_segms = torch.unsqueeze(segm, 0) + + output = dict() + output['img_ori'] = np.array(img) + output['img_data'] = [x.contiguous() for x in img_resized_list] + output['seg_label'] = batch_segms.contiguous() + output['info'] = this_record['fpath_img'] + return output + + def __len__(self): + return self.num_sample + + +class TestDataset(BaseDataset): + def __init__(self, odgt, opt, **kwargs): + super(TestDataset, self).__init__(odgt, opt, **kwargs) + + def __getitem__(self, index): + this_record = self.list_sample[index] + # load image + image_path = this_record['fpath_img'] + img = Image.open(image_path).convert('RGB') + + ori_width, ori_height = img.size + + img_resized_list = [] + for this_short_size in self.imgSizes: + # calculate target height and width + scale = min(this_short_size / float(min(ori_height, ori_width)), + self.imgMaxSize / float(max(ori_height, ori_width))) + target_height, target_width = int(ori_height * scale), int(ori_width * scale) + + # to avoid rounding in network + target_width = self.round2nearest_multiple(target_width, self.padding_constant) + target_height = self.round2nearest_multiple(target_height, self.padding_constant) + + # resize images + img_resized = imresize(img, (target_width, target_height), interp='bilinear') + + # image transform, to torch float tensor 3xHxW + img_resized = self.img_transform(img_resized) + img_resized = torch.unsqueeze(img_resized, 0) + img_resized_list.append(img_resized) + + output = dict() + output['img_ori'] = np.array(img) + output['img_data'] = [x.contiguous() for x in img_resized_list] + output['info'] = this_record['fpath_img'] + return output + + def __len__(self): + return self.num_sample diff --git a/third_party/gim/gim/mit_semseg/lib/__init__.py b/third_party/gim/gim/mit_semseg/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
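A hypothetical end-to-end use of TestDataset (the odgt record and image path are made up; `cfg` is the yacs default tree from mit_semseg.config in this diff, assuming the vendored package imports as `gim.mit_semseg`):

from gim.mit_semseg.config import cfg
from gim.mit_semseg.dataset import TestDataset

odgt = [{"fpath_img": "demo/img1.jpg"}]   # hypothetical single-image record
dataset = TestDataset(odgt, cfg.DATASET)
sample = dataset[0]                       # 'img_ori', multi-scale 'img_data', 'info'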
diff --git a/third_party/gim/gim/mit_semseg/lib/nn/__init__.py b/third_party/gim/gim/mit_semseg/lib/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98a96370ef04570f516052bb73f568d0ebc346c3 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/__init__.py @@ -0,0 +1,2 @@ +from .modules import * +from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/__init__.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/modules/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/batchnorm.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..18318965335b37cc671004a6aceda3229dc7b477 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/modules/batchnorm.py @@ -0,0 +1,329 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimension""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dimensions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + # custom batch norm statistics + self._moving_average_fraction = 1. - momentum + self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) + self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) + self.register_buffer('_running_iter', torch.ones(1)) + self._tmp_running_mean = self.running_mean.clone() * self._running_iter + self._tmp_running_var = self.running_var.clone() * self._running_iter + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): + """return *dest* by `dest := dest*alpha + delta*beta + bias`""" + return dest * alpha + delta * beta + bias + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) + self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) + self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) + + self.running_mean = self._tmp_running_mean / self._running_iter + self.running_var = self._tmp_running_var / self._running_iter + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. 
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalizes the tensor on each device using + the statistics only on that device, which accelerates the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly the same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalizes the tensor on each device using + the statistics only on that device, which accelerates the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly the same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size).
+ + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/comm.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64bf6ba3b3e7abbab375c6dd4a87d8239e62138
--- /dev/null
+++ b/third_party/gim/gim/mit_semseg/lib/nn/modules/comm.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+    """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+    def __init__(self):
+        self._result = None
+        self._lock = threading.Lock()
+        self._cond = threading.Condition(self._lock)
+
+    def put(self, result):
+        with self._lock:
+            assert self._result is None, 'Previous result hasn\'t been fetched.'
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During the replication, as data parallel will trigger a callback on each module, all slave devices should
+    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
+    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be
+    collected and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine the message
+    to be passed back to each slave device.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def register_slave(self, identifier):
+        """
+        Register a slave device.
+
+        Args:
+            identifier: an identifier, usually the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
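+            # A previous forward pass has completed; clear the stale registry before slaves re-register.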
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages are first collected from each device (including the master device), and then
+        a callback will be invoked to compute the message to be sent back to each device
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belong to the master.'
+
+        for i, res in results:
+            if i == 0:
+                continue
+            self._registry[i].result.put(res)
+
+        for i in range(self.nr_slaves):
+            assert self._queue.get() is True
+
+        return results[0][1]
+
+    @property
+    def nr_slaves(self):
+        return len(self._registry)
diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/replicate.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06
--- /dev/null
+++ b/third_party/gim/gim/mit_semseg/lib/nn/modules/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+    'CallbackContext',
+    'execute_replication_callbacks',
+    'DataParallelWithCallback',
+    'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+    pass
+
+
+def execute_replication_callbacks(modules):
+    """
+    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Note that, as all modules are isomorphic, we assign each sub-module with a context
+    (shared among multiple copies of this module on different devices).
+    Through this context, different copies can share some information.
+
+    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+    of any slave copies.
+    """
+    master_copy = modules[0]
+    nr_modules = len(list(master_copy.modules()))
+    ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+    for i, module in enumerate(modules):
+        for j, m in enumerate(module.modules()):
+            if hasattr(m, '__data_parallel_replicate__'):
+                m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+    """
+    Data Parallel with a replication callback.
+
+    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+    the original `replicate` function.
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/__init__.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/test_numeric_batchnorm.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/test_numeric_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd45a930d3dc84912e58659ee575be08e9038f0 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/test_numeric_batchnorm.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
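+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.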
+ +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +class NumericTestCase(TorchTestCase): + def testNumericBatchNorm(self): + a = torch.rand(16, 10) + bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) + bn.train() + + a_var1 = Variable(a, requires_grad=True) + b_var1 = bn(a_var1) + loss1 = b_var1.sum() + loss1.backward() + + a_var2 = Variable(a, requires_grad=True) + a_mean2 = a_var2.mean(dim=0, keepdim=True) + a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) + # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) + b_var2 = (a_var2 - a_mean2) / a_std2 + loss2 = b_var2.sum() + loss2.backward() + + self.assertTensorClose(bn.running_mean, a.mean(dim=0)) + self.assertTensorClose(bn.running_var, handy_var(a)) + self.assertTensorClose(a_var1.data, a_var2.data) + self.assertTensorClose(b_var1.data, b_var2.data) + self.assertTensorClose(a_var1.grad, a_var2.grad) + + +if __name__ == '__main__': + unittest.main() diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/test_sync_batchnorm.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/test_sync_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..45bb3c8cfd36d8f668e6fde756b17587eab72082 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/modules/tests/test_sync_batchnorm.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# File : test_sync_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
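+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.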
+ +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +def _find_bn(module): + for m in module.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): + return m + + +class SyncTestCase(TorchTestCase): + def _syncParameters(self, bn1, bn2): + bn1.reset_parameters() + bn2.reset_parameters() + if bn1.affine and bn2.affine: + bn2.weight.data.copy_(bn1.weight.data) + bn2.bias.data.copy_(bn1.bias.data) + + def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): + """Check the forward and backward for the customized batch normalization.""" + bn1.train(mode=is_train) + bn2.train(mode=is_train) + + if cuda: + input = input.cuda() + + self._syncParameters(_find_bn(bn1), _find_bn(bn2)) + + input1 = Variable(input, requires_grad=True) + output1 = bn1(input1) + output1.sum().backward() + input2 = Variable(input, requires_grad=True) + output2 = bn2(input2) + output2.sum().backward() + + self.assertTensorClose(input1.data, input2.data) + self.assertTensorClose(output1.data, output2.data) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) + self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) + + def testSyncBatchNormNormalTrain(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) + + def testSyncBatchNormNormalEval(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) + + def testSyncBatchNormSyncTrain(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) + + def testSyncBatchNormSyncEval(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) + + def testSyncBatchNorm2DSyncTrain(self): + bn = nn.BatchNorm2d(10) + sync_bn = SynchronizedBatchNorm2d(10) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/third_party/gim/gim/mit_semseg/lib/nn/modules/unittest.py b/third_party/gim/gim/mit_semseg/lib/nn/modules/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/modules/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part 
of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/third_party/gim/gim/mit_semseg/lib/nn/parallel/__init__.py b/third_party/gim/gim/mit_semseg/lib/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b52f49cc0755562218a460483cbf02514ddd773 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/parallel/__init__.py @@ -0,0 +1 @@ +from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/third_party/gim/gim/mit_semseg/lib/nn/parallel/data_parallel.py b/third_party/gim/gim/mit_semseg/lib/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..376fc038919aa2a5bd696141e7bb6025d4981306 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/nn/parallel/data_parallel.py @@ -0,0 +1,112 @@ +# -*- coding: utf8 -*- + +import torch.cuda as cuda +import torch.nn as nn +import torch +import collections +from torch.nn.parallel._functions import Gather + + +__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] + + +def async_copy_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + v = obj.cuda(dev, non_blocking=True) + if main_stream is not None: + v.data.record_stream(main_stream) + return v + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] + else: + return obj + + +def dict_gather(outputs, target_device, dim=0): + """ + Gathers variables from different GPUs on a specified device + (-1 means the CPU), with dictionary support. 
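+    Mappings and sequences are gathered recursively, so nested outputs keep their structure.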
+ """ + def gather_map(outputs): + out = outputs[0] + if torch.is_tensor(out): + # MJY(20180330) HACK:: force nr_dims > 0 + if out.dim() == 0: + outputs = [o.unsqueeze(0) for o in outputs] + return Gather.apply(target_device, dim, *outputs) + elif out is None: + return None + elif isinstance(out, collections.Mapping): + return {k: gather_map([o[k] for o in outputs]) for k in out} + elif isinstance(out, collections.Sequence): + return type(out)(map(gather_map, zip(*outputs))) + return gather_map(outputs) + + +class DictGatherDataParallel(nn.DataParallel): + def gather(self, outputs, output_device): + return dict_gather(outputs, output_device, dim=self.dim) + + +class UserScatteredDataParallel(DictGatherDataParallel): + def scatter(self, inputs, kwargs, device_ids): + assert len(inputs) == 1 + inputs = inputs[0] + inputs = _async_copy_stream(inputs, device_ids) + inputs = [[i] for i in inputs] + assert len(kwargs) == 0 + kwargs = [{} for _ in range(len(inputs))] + + return inputs, kwargs + + +def user_scattered_collate(batch): + return batch + + +def _async_copy(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + for i, dev in zip(inputs, device_ids): + with cuda.device(dev): + outputs.append(async_copy_to(i, dev)) + + return tuple(outputs) + + +def _async_copy_stream(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + streams = [_get_stream(d) for d in device_ids] + for i, dev, stream in zip(inputs, device_ids, streams): + with cuda.device(dev): + main_stream = cuda.current_stream() + with cuda.stream(stream): + outputs.append(async_copy_to(i, dev, main_stream=main_stream)) + main_stream.wait_stream(stream) + + return outputs + + +"""Adapted from: torch/nn/parallel/_functions.py""" +# background streams used for copying +_streams = None + + +def _get_stream(device): + """Gets a background stream for copying between CPU and GPU""" + global _streams + if device == -1: + return None + if _streams is None: + _streams = [None] * cuda.device_count() + if _streams[device] is None: _streams[device] = cuda.Stream(device) + return _streams[device] diff --git a/third_party/gim/gim/mit_semseg/lib/utils/__init__.py b/third_party/gim/gim/mit_semseg/lib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abe3cbe49477fe37d4fc16249de8a10f4fb4a013 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/utils/__init__.py @@ -0,0 +1 @@ +from .th import * diff --git a/third_party/gim/gim/mit_semseg/lib/utils/data/__init__.py b/third_party/gim/gim/mit_semseg/lib/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b008fb13c5e8a84b1b785056e8c4f5226dc976 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/utils/data/__init__.py @@ -0,0 +1,3 @@ + +from .dataset import Dataset, TensorDataset, ConcatDataset +from .dataloader import DataLoader diff --git a/third_party/gim/gim/mit_semseg/lib/utils/data/dataloader.py b/third_party/gim/gim/mit_semseg/lib/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..039b9ec3645b2a4626ff47c221e372f32a6ad339 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/utils/data/dataloader.py @@ -0,0 +1,425 @@ +import torch +import torch.multiprocessing as multiprocessing +from torch._C import _set_worker_signal_handlers, \ + _remove_worker_pids, _error_if_any_worker_fails +try: + from torch._C import 
_set_worker_pids +except: + from torch._C import _update_worker_pids as _set_worker_pids +from .sampler import SequentialSampler, RandomSampler, BatchSampler +import signal +import collections +import re +import sys +import threading +import traceback +from torch._six import string_classes, int_classes +import numpy as np + +if sys.version_info[0] == 2: + import Queue as queue +else: + import queue + + +class ExceptionWrapper(object): + r"Wraps an exception plus traceback to communicate across threads" + + def __init__(self, exc_info): + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + + +_use_shared_memory = False +"""Whether to use shared memory in default_collate""" + + +def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): + global _use_shared_memory + _use_shared_memory = True + + # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal happened again already. + # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 + _set_worker_signal_handlers() + + torch.set_num_threads(1) + torch.manual_seed(seed) + np.random.seed(seed) + + if init_fn is not None: + init_fn(worker_id) + + while True: + r = index_queue.get() + if r is None: + break + idx, batch_indices = r + try: + samples = collate_fn([dataset[i] for i in batch_indices]) + except Exception: + data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + data_queue.put((idx, samples)) + + +def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): + if pin_memory: + torch.cuda.set_device(device_id) + + while True: + try: + r = in_queue.get() + except Exception: + if done_event.is_set(): + return + raise + if r is None: + break + if isinstance(r[1], ExceptionWrapper): + out_queue.put(r) + continue + idx, batch = r + try: + if pin_memory: + batch = pin_memory_batch(batch) + except Exception: + out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + out_queue.put((idx, batch)) + +numpy_type_map = { + 'float64': torch.DoubleTensor, + 'float32': torch.FloatTensor, + 'float16': torch.HalfTensor, + 'int64': torch.LongTensor, + 'int32': torch.IntTensor, + 'int16': torch.ShortTensor, + 'int8': torch.CharTensor, + 'uint8': torch.ByteTensor, +} + + +def default_collate(batch): + "Puts each data field into a tensor with outer dimension batch size" + + error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" + elem_type = type(batch[0]) + if torch.is_tensor(batch[0]): + out = None + if _use_shared_memory: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = batch[0].storage()._new_shared(numel) + out = batch[0].new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + elem = batch[0] + if elem_type.__name__ == 'ndarray': + # array of string classes and object + if re.search('[SaUO]', elem.dtype.str) is not None: + raise TypeError(error_msg.format(elem.dtype)) + + return torch.stack([torch.from_numpy(b) for b in batch], 0) + if elem.shape == (): # scalars + py_type = float if elem.dtype.name.startswith('float') else int + return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) + elif isinstance(batch[0], int_classes): + return 
torch.LongTensor(batch) + elif isinstance(batch[0], float): + return torch.DoubleTensor(batch) + elif isinstance(batch[0], string_classes): + return batch + elif isinstance(batch[0], collections.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in batch[0]} + elif isinstance(batch[0], collections.Sequence): + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] + + raise TypeError((error_msg.format(type(batch[0])))) + + +def pin_memory_batch(batch): + if torch.is_tensor(batch): + return batch.pin_memory() + elif isinstance(batch, string_classes): + return batch + elif isinstance(batch, collections.Mapping): + return {k: pin_memory_batch(sample) for k, sample in batch.items()} + elif isinstance(batch, collections.Sequence): + return [pin_memory_batch(sample) for sample in batch] + else: + return batch + + +_SIGCHLD_handler_set = False +"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if sys.platform == 'win32': + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True + + +class DataLoaderIter(object): + "Iterates once over the DataLoader's dataset, as specified by the sampler" + + def __init__(self, loader): + self.dataset = loader.dataset + self.collate_fn = loader.collate_fn + self.batch_sampler = loader.batch_sampler + self.num_workers = loader.num_workers + self.pin_memory = loader.pin_memory and torch.cuda.is_available() + self.timeout = loader.timeout + self.done_event = threading.Event() + + self.sample_iter = iter(self.batch_sampler) + + if self.num_workers > 0: + self.worker_init_fn = loader.worker_init_fn + self.index_queue = multiprocessing.SimpleQueue() + self.worker_result_queue = multiprocessing.SimpleQueue() + self.batches_outstanding = 0 + self.worker_pids_set = False + self.shutdown = False + self.send_idx = 0 + self.rcvd_idx = 0 + self.reorder_dict = {} + + base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] + self.workers = [ + multiprocessing.Process( + target=_worker_loop, + args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, + base_seed + i, self.worker_init_fn, i)) + for i in range(self.num_workers)] + + if self.pin_memory or self.timeout > 0: + self.data_queue = queue.Queue() + if self.pin_memory: + maybe_device_id = torch.cuda.current_device() + else: + # do not initialize cuda context if not necessary + maybe_device_id = None + self.worker_manager_thread = threading.Thread( + target=_worker_manager_loop, + args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, + maybe_device_id)) + self.worker_manager_thread.daemon = True + self.worker_manager_thread.start() + else: + self.data_queue = self.worker_result_queue + + for w in self.workers: + w.daemon = True # ensure that the worker exits on 
process exit + w.start() + + _set_worker_pids(id(self), tuple(w.pid for w in self.workers)) + _set_SIGCHLD_handler() + self.worker_pids_set = True + + # prime the prefetch loop + for _ in range(2 * self.num_workers): + self._put_indices() + + def __len__(self): + return len(self.batch_sampler) + + def _get_batch(self): + if self.timeout > 0: + try: + return self.data_queue.get(timeout=self.timeout) + except queue.Empty: + raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) + else: + return self.data_queue.get() + + def __next__(self): + if self.num_workers == 0: # same-process loading + indices = next(self.sample_iter) # may raise StopIteration + batch = self.collate_fn([self.dataset[i] for i in indices]) + if self.pin_memory: + batch = pin_memory_batch(batch) + return batch + + # check if the next sample has already been generated + if self.rcvd_idx in self.reorder_dict: + batch = self.reorder_dict.pop(self.rcvd_idx) + return self._process_next_batch(batch) + + if self.batches_outstanding == 0: + self._shutdown_workers() + raise StopIteration + + while True: + assert (not self.shutdown and self.batches_outstanding > 0) + idx, batch = self._get_batch() + self.batches_outstanding -= 1 + if idx != self.rcvd_idx: + # store out-of-order samples + self.reorder_dict[idx] = batch + continue + return self._process_next_batch(batch) + + next = __next__ # Python 2 compatibility + + def __iter__(self): + return self + + def _put_indices(self): + assert self.batches_outstanding < 2 * self.num_workers + indices = next(self.sample_iter, None) + if indices is None: + return + self.index_queue.put((self.send_idx, indices)) + self.batches_outstanding += 1 + self.send_idx += 1 + + def _process_next_batch(self, batch): + self.rcvd_idx += 1 + self._put_indices() + if isinstance(batch, ExceptionWrapper): + raise batch.exc_type(batch.exc_msg) + return batch + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. + # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("DataLoaderIterator cannot be pickled") + + def _shutdown_workers(self): + try: + if not self.shutdown: + self.shutdown = True + self.done_event.set() + # if worker_manager_thread is waiting to put + while not self.data_queue.empty(): + self.data_queue.get() + for _ in self.workers: + self.index_queue.put(None) + # done_event should be sufficient to exit worker_manager_thread, + # but be safe here and put another None + self.worker_result_queue.put(None) + finally: + # removes pids no matter what + if self.worker_pids_set: + _remove_worker_pids(id(self)) + self.worker_pids_set = False + + def __del__(self): + if self.num_workers > 0: + self._shutdown_workers() + + +class DataLoader(object): + """ + Data loader. Combines a dataset and a sampler, and provides + single- or multi-process iterators over the dataset. + + Arguments: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: 1). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: False). + sampler (Sampler, optional): defines the strategy to draw samples from + the dataset. If specified, ``shuffle`` must be False. 
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of + indices at a time. Mutually exclusive with batch_size, shuffle, + sampler, and drop_last. + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means that the data will be loaded in the main process. + (default: 0) + collate_fn (callable, optional): merges a list of samples to form a mini-batch. + pin_memory (bool, optional): If ``True``, the data loader will copy tensors + into CUDA pinned memory before returning them. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: False) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: 0) + worker_init_fn (callable, optional): If not None, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: None) + + .. note:: By default, each worker will have its PyTorch seed set to + ``base_seed + worker_id``, where ``base_seed`` is a long generated + by main process using its RNG. You may use ``torch.initial_seed()`` to access + this value in :attr:`worker_init_fn`, which can be used to set other seeds + (e.g. NumPy) before data loading. + + .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an + unpicklable object, e.g., a lambda function. + """ + + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, + num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None): + self.dataset = dataset + self.batch_size = batch_size + self.num_workers = num_workers + self.collate_fn = collate_fn + self.pin_memory = pin_memory + self.drop_last = drop_last + self.timeout = timeout + self.worker_init_fn = worker_init_fn + + if timeout < 0: + raise ValueError('timeout option should be non-negative') + + if batch_sampler is not None: + if batch_size > 1 or shuffle or sampler is not None or drop_last: + raise ValueError('batch_sampler is mutually exclusive with ' + 'batch_size, shuffle, sampler, and drop_last') + + if sampler is not None and shuffle: + raise ValueError('sampler is mutually exclusive with shuffle') + + if self.num_workers < 0: + raise ValueError('num_workers cannot be negative; ' + 'use num_workers=0 to disable multiprocessing.') + + if batch_sampler is None: + if sampler is None: + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.sampler = sampler + self.batch_sampler = batch_sampler + + def __iter__(self): + return DataLoaderIter(self) + + def __len__(self): + return len(self.batch_sampler) diff --git a/third_party/gim/gim/mit_semseg/lib/utils/data/dataset.py b/third_party/gim/gim/mit_semseg/lib/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..605aa877f7031a5cd2b98c0f831410aa80fddefa --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/utils/data/dataset.py @@ -0,0 +1,118 @@ +import bisect +import warnings + +from torch._utils import _accumulate +from torch import randperm + + +class Dataset(object): + """An abstract class representing a Dataset. 
+
+    All other datasets should subclass it. All subclasses should override
+    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
+    supporting integer indexing in range from 0 to len(self) exclusive.
+    """
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+    def __add__(self, other):
+        return ConcatDataset([self, other])
+
+
+class TensorDataset(Dataset):
+    """Dataset wrapping data and target tensors.
+
+    Each sample will be retrieved by indexing both tensors along the first
+    dimension.
+
+    Arguments:
+        data_tensor (Tensor): contains sample data.
+        target_tensor (Tensor): contains sample targets (labels).
+    """
+
+    def __init__(self, data_tensor, target_tensor):
+        assert data_tensor.size(0) == target_tensor.size(0)
+        self.data_tensor = data_tensor
+        self.target_tensor = target_tensor
+
+    def __getitem__(self, index):
+        return self.data_tensor[index], self.target_tensor[index]
+
+    def __len__(self):
+        return self.data_tensor.size(0)
+
+
+class ConcatDataset(Dataset):
+    """
+    Dataset to concatenate multiple datasets.
+    This is useful to assemble different existing datasets, possibly
+    large-scale datasets, as the concatenation operation is done in an
+    on-the-fly manner.
+
+    Arguments:
+        datasets (iterable): List of datasets to be concatenated
+    """
+
+    @staticmethod
+    def cumsum(sequence):
+        r, s = [], 0
+        for e in sequence:
+            l = len(e)
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__()
+        assert len(datasets) > 0, 'datasets should not be an empty iterable'
+        self.datasets = list(datasets)
+        self.cumulative_sizes = self.cumsum(self.datasets)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx][sample_idx]
+
+    @property
+    def cummulative_sizes(self):
+        warnings.warn("cummulative_sizes attribute is renamed to "
+                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
+        return self.cumulative_sizes
+
+
+class Subset(Dataset):
+    def __init__(self, dataset, indices):
+        self.dataset = dataset
+        self.indices = indices
+
+    def __getitem__(self, idx):
+        return self.dataset[self.indices[idx]]
+
+    def __len__(self):
+        return len(self.indices)
+
+
+def random_split(dataset, lengths):
+    """
+    Randomly split a dataset into non-overlapping new datasets of given lengths.
+
+    Arguments:
+        dataset (Dataset): Dataset to be split
+        lengths (iterable): lengths of splits to be produced
+    """
+    if sum(lengths) != len(dataset):
+        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
+
+    indices = randperm(sum(lengths))
+    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
diff --git a/third_party/gim/gim/mit_semseg/lib/utils/data/distributed.py b/third_party/gim/gim/mit_semseg/lib/utils/data/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3d890e28fd2b9e044bdd9494de4a43ad2471eed
--- /dev/null
+++ b/third_party/gim/gim/mit_semseg/lib/utils/data/distributed.py
@@ -0,0 +1,58 @@
+import math
+import torch
+from .sampler import Sampler
+from torch.distributed import get_world_size, get_rank
+
+
+class DistributedSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset.
+ + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + num_replicas = get_world_size() + if rank is None: + rank = get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = list(torch.randperm(len(self.dataset), generator=g)) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/third_party/gim/gim/mit_semseg/lib/utils/data/sampler.py b/third_party/gim/gim/mit_semseg/lib/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..62a9a43bd1d4c21fbdcb262db7da8d4fe27b26de --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/utils/data/sampler.py @@ -0,0 +1,131 @@ +import torch + + +class Sampler(object): + """Base class for all Samplers. + + Every Sampler subclass has to provide an __iter__ method, providing a way + to iterate over indices of dataset elements, and a __len__ method that + returns the length of the returned iterators. + """ + + def __init__(self, data_source): + pass + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +class SequentialSampler(Sampler): + """Samples elements sequentially, always in the same order. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + +class RandomSampler(Sampler): + """Samples elements randomly, without replacement. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(torch.randperm(len(self.data_source)).long()) + + def __len__(self): + return len(self.data_source) + + +class SubsetRandomSampler(Sampler): + """Samples elements randomly from a given list of indices, without replacement. + + Arguments: + indices (list): a list of indices + """ + + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in torch.randperm(len(self.indices))) + + def __len__(self): + return len(self.indices) + + +class WeightedRandomSampler(Sampler): + """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 
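+
+    Example (illustrative; the drawn indices are random):
+        >>> sampler = WeightedRandomSampler([0.1, 0.9, 0.4], num_samples=5)
+        >>> indices = list(sampler)  # five indices drawn from {0, 1, 2}, biased toward index 1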
+ + Arguments: + weights (list) : a list of weights, not necessary summing up to one + num_samples (int): number of samples to draw + replacement (bool): if ``True``, samples are drawn with replacement. + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + """ + + def __init__(self, weights, num_samples, replacement=True): + self.weights = torch.DoubleTensor(weights) + self.num_samples = num_samples + self.replacement = replacement + + def __iter__(self): + return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) + + def __len__(self): + return self.num_samples + + +class BatchSampler(object): + """Wraps another sampler to yield a mini-batch of indices. + + Args: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + + Example: + >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__(self, sampler, batch_size, drop_last): + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/third_party/gim/gim/mit_semseg/lib/utils/th.py b/third_party/gim/gim/mit_semseg/lib/utils/th.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6ef9385e3b5c0a439579d3fd7aa73b5dc62758 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/lib/utils/th.py @@ -0,0 +1,41 @@ +import torch +from torch.autograd import Variable +import numpy as np +import collections + +__all__ = ['as_variable', 'as_numpy', 'mark_volatile'] + +def as_variable(obj): + if isinstance(obj, Variable): + return obj + if isinstance(obj, collections.Sequence): + return [as_variable(v) for v in obj] + elif isinstance(obj, collections.Mapping): + return {k: as_variable(v) for k, v in obj.items()} + else: + return Variable(obj) + +def as_numpy(obj): + if isinstance(obj, collections.Sequence): + return [as_numpy(v) for v in obj] + elif isinstance(obj, collections.Mapping): + return {k: as_numpy(v) for k, v in obj.items()} + elif isinstance(obj, Variable): + return obj.data.cpu().numpy() + elif torch.is_tensor(obj): + return obj.cpu().numpy() + else: + return np.array(obj) + +def mark_volatile(obj): + if torch.is_tensor(obj): + obj = Variable(obj) + if isinstance(obj, Variable): + obj.no_grad = True + return obj + elif isinstance(obj, collections.Mapping): + return {k: mark_volatile(o) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [mark_volatile(o) for o in obj] + else: + return obj diff --git a/third_party/gim/gim/mit_semseg/models/__init__.py b/third_party/gim/gim/mit_semseg/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76b40a0a36bc2976f185dbdc344c5a7c09b65920 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/models/__init__.py @@ -0,0 +1 @@ +from .models import ModelBuilder, SegmentationModule diff --git 
a/third_party/gim/gim/mit_semseg/models/hrnet.py b/third_party/gim/gim/mit_semseg/models/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..579f3c5e4979d5c3896a393e211925b6ff85c8e4 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/models/hrnet.py @@ -0,0 +1,445 @@ +""" +This HRNet implementation is modified from the following repository: +https://github.com/HRNet/HRNet-Semantic-Segmentation +""" + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d + +BatchNorm2d = SynchronizedBatchNorm2d +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +__all__ = ['hrnetv2'] + + +model_urls = { + 'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != 
len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=(height_output, width_output), + mode='bilinear', + align_corners=False) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + 
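+        # x_fuse holds one fused, ReLU-activated feature map per retained branch, one per output resolution.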
+ return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HRNetV2(nn.Module): + def __init__(self, n_class, **kwargs): + super(HRNetV2, self).__init__() + extra = { + 'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'}, + 'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'}, + 'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'}, + 'FINAL_CONV_KERNEL': 1 + } + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(Bottleneck, 64, 64, 4) + + self.stage2_cfg = extra['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = extra['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = extra['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * 
block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, return_feature_maps=False): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False) + + x = torch.cat([x[0], x1, x2, x3], 1) + + # x = self.last_layer(x) + return [x] + + +def hrnetv2(pretrained=False, **kwargs): + model = HRNetV2(n_class=1000, **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['hrnetv2']), strict=False) + + return model diff --git a/third_party/gim/gim/mit_semseg/models/mobilenet.py b/third_party/gim/gim/mit_semseg/models/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0ddec4b1747dfe7b22ee61c78a4dd75187f645 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/models/mobilenet.py @@ -0,0 +1,154 @@ +""" +This MobileNetV2 implementation is modified from the following repository: +https://github.com/tonylins/pytorch-mobilenet-v2 +""" + +import torch.nn as nn +import math +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d + +BatchNorm2d = SynchronizedBatchNorm2d + + +__all__ = ['mobilenetv2'] + + +model_urls = { + 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', +} + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def 
conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, n_class=1000, input_size=224, width_mult=1.): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + assert input_size % 32 == 0 + input_channel = int(input_channel * width_mult) + self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel + self.features = [conv_bn(3, input_channel, 2)] + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = int(c * width_mult) + for i in range(n): + if i == 0: + self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) + else: + self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) + input_channel = output_channel + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, n_class), + ) + + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = x.mean(3).mean(2) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def mobilenetv2(pretrained=False, **kwargs): + """Constructs a MobileNet_V2 model. 
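Editor's note: in the setting table below, each `(t, c, n, s)` row gives the expansion factor, output channels, repeat count, and first-block stride, and `InvertedResidual` only adds a skip connection when the stride is 1 and input and output widths match. A small trace of one row, pure arithmetic mirroring the constructor's loop:

```python
# Row (t=6, c=32, n=3, s=2), entered with 24 input channels
# (the output of the previous row), as in the table below.
inp, t, c, n, s = 24, 6, 32, 3, 2

for i in range(n):
    stride = s if i == 0 else 1          # only the first repeat downsamples
    hidden_dim = round(inp * t)          # pointwise expansion width
    use_res = stride == 1 and inp == c   # residual only at stride 1, equal widths
    print(f"block {i}: {inp} -> expand {hidden_dim} -> {c}, "
          f"stride {stride}, residual={use_res}")
    inp = c
# block 0: 24 -> expand 144 -> 32, stride 2, residual=False
# block 1: 32 -> expand 192 -> 32, stride 1, residual=True
# block 2: 32 -> expand 192 -> 32, stride 1, residual=True
```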
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = MobileNetV2(n_class=1000, **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) + return model diff --git a/third_party/gim/gim/mit_semseg/models/models.py b/third_party/gim/gim/mit_semseg/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a624d08d776e26b353d41d6a48f9feaa1852ed --- /dev/null +++ b/third_party/gim/gim/mit_semseg/models/models.py @@ -0,0 +1,586 @@ +import torch +import torch.nn as nn +from . import resnet, resnext, mobilenet, hrnet +from ..lib.nn import SynchronizedBatchNorm2d +BatchNorm2d = SynchronizedBatchNorm2d + + +class SegmentationModuleBase(nn.Module): + def __init__(self): + super(SegmentationModuleBase, self).__init__() + + def pixel_acc(self, pred, label): + _, preds = torch.max(pred, dim=1) + valid = (label >= 0).long() + acc_sum = torch.sum(valid * (preds == label).long()) + pixel_sum = torch.sum(valid) + acc = acc_sum.float() / (pixel_sum.float() + 1e-10) + return acc + + +class SegmentationModule(SegmentationModuleBase): + def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): + super(SegmentationModule, self).__init__() + self.encoder = net_enc + self.decoder = net_dec + self.crit = crit + self.deep_sup_scale = deep_sup_scale + + def forward(self, feed_dict, *, segSize=None): + # training + if segSize is None: + if self.deep_sup_scale is not None: # use deep supervision technique + (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) + else: + pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) + + loss = self.crit(pred, feed_dict['seg_label']) + if self.deep_sup_scale is not None: + loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) + loss = loss + loss_deepsup * self.deep_sup_scale + + acc = self.pixel_acc(pred, feed_dict['seg_label']) + return loss, acc + # inference + else: + pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) + return pred + + +class ModelBuilder: + # custom weights initialization + @staticmethod + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) 
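Editor's note: `pixel_acc` above computes accuracy only over labeled pixels, treating negative labels as ignore regions. A standalone check (the body is copied from the class so the snippet runs on its own):

```python
import torch

def pixel_acc(pred, label):
    # pred: (B, C, H, W) logits; label: (B, H, W), negative = ignore
    _, preds = torch.max(pred, dim=1)
    valid = (label >= 0).long()
    acc_sum = torch.sum(valid * (preds == label).long())
    pixel_sum = torch.sum(valid)
    return acc_sum.float() / (pixel_sum.float() + 1e-10)

pred = torch.tensor([[[[2.0, 0.0]], [[0.0, 2.0]]]])  # (1, 2, 1, 2): argmax = [0, 1]
label = torch.tensor([[[0, -1]]])                    # second pixel is ignored
print(pixel_acc(pred, label))                        # tensor(1.0000)
```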
+ m.bias.data.fill_(1e-4) + #elif classname.find('Linear') != -1: + # m.weight.data.normal_(0.0, 0.0001) + + @staticmethod + def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''): + pretrained = True if len(weights) == 0 else False + arch = arch.lower() + if arch == 'mobilenetv2dilated': + orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained) + net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8) + elif arch == 'resnet18': + orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet18dilated': + orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet34': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet34dilated': + raise NotImplementedError + orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet50dilated': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet101': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet101dilated': + orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnext101': + orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained) + net_encoder = Resnet(orig_resnext) # we can still use class Resnet + elif arch == 'hrnetv2': + net_encoder = hrnet.__dict__['hrnetv2'](pretrained=pretrained) + else: + raise Exception('Architecture undefined!') + + # encoders are usually pretrained + # net_encoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + # print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_encoder + + @staticmethod + def build_decoder(arch='ppm_deepsup', + fc_dim=512, num_class=150, + weights='', use_softmax=False): + arch = arch.lower() + if arch == 'c1_deepsup': + net_decoder = C1DeepSup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'c1': + net_decoder = C1( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'ppm': + net_decoder = PPM( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'ppm_deepsup': + net_decoder = PPMDeepsup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax) + elif arch == 'upernet_lite': + net_decoder = UPerNet( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + fpn_dim=256) + elif arch == 'upernet': + net_decoder = UPerNet( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + fpn_dim=512) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + # print('Loading weights for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_decoder + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + "3x3 convolution + BN + relu" + return 
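Editor's note: wiring `ModelBuilder` and `SegmentationModule` together looks roughly like the sketch below. It assumes this `models.py` is importable (e.g. `from gim.mit_semseg.models.models import ModelBuilder, SegmentationModule`) and that the ImageNet backbone checkpoint can be fetched, since `weights=""` makes `build_encoder` download pretrained weights via `load_url`; the input tensor and the NLL criterion are stand-ins, not values fixed by this diff.

```python
import torch
import torch.nn as nn

net_encoder = ModelBuilder.build_encoder(
    arch="resnet50dilated", fc_dim=2048, weights="")  # weights="" -> pretrained download
net_decoder = ModelBuilder.build_decoder(
    arch="ppm_deepsup", fc_dim=2048, num_class=150,
    weights="", use_softmax=True)  # use_softmax=True selects the inference branch

crit = nn.NLLLoss(ignore_index=-1)  # training path emits log-probabilities
model = SegmentationModule(net_encoder, net_decoder, crit).eval()

with torch.no_grad():  # passing segSize routes through the inference path
    scores = model({"img_data": torch.randn(1, 3, 256, 256)}, segSize=(256, 256))
print(scores.shape)    # torch.Size([1, 150, 256, 256]): per-pixel class probabilities
```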
nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=1, bias=False), + BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +class Resnet(nn.Module): + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + + +class ResnetDilated(nn.Module): + def __init__(self, orig_resnet, dilate_scale=8): + super(ResnetDilated, self).__init__() + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply( + partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=4)) + elif dilate_scale == 16: + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=2)) + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate//2, dilate//2) + m.padding = (dilate//2, dilate//2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + + +class MobileNetV2Dilated(nn.Module): + def __init__(self, orig_net, dilate_scale=8): + super(MobileNetV2Dilated, self).__init__() + from functools import partial + + # take pretrained mobilenet features + self.features = orig_net.features[:-1] + + self.total_idx = len(self.features) + self.down_idx = [2, 4, 7, 14] + + if dilate_scale == 8: + for i in range(self.down_idx[-2], self.down_idx[-1]): + self.features[i].apply( + partial(self._nostride_dilate, dilate=2) + ) + for i in range(self.down_idx[-1], 
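Editor's note: `_nostride_dilate`, used by both dilated wrappers in this file, is the standard stride-to-dilation conversion: stride-2 convs become stride-1, and 3x3 kernels gain dilation so the receptive field keeps growing while the feature map stops shrinking. A standalone version (with an `isinstance` check in place of the classname string test) applied to two convs:

```python
import torch.nn as nn
from functools import partial

def _nostride_dilate(m, dilate):
    # same transform as in ResnetDilated/MobileNetV2Dilated above
    if isinstance(m, nn.Conv2d):
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

conv_s2 = nn.Conv2d(64, 64, 3, stride=2, padding=1)
conv_s1 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
conv_s2.apply(partial(_nostride_dilate, dilate=2))
conv_s1.apply(partial(_nostride_dilate, dilate=2))
print(conv_s2.stride, conv_s2.dilation)  # (1, 1) (1, 1): stride removed
print(conv_s1.stride, conv_s1.dilation)  # (1, 1) (2, 2): dilated instead
```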
self.total_idx): + self.features[i].apply( + partial(self._nostride_dilate, dilate=4) + ) + elif dilate_scale == 16: + for i in range(self.down_idx[-1], self.total_idx): + self.features[i].apply( + partial(self._nostride_dilate, dilate=2) + ) + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate//2, dilate//2) + m.padding = (dilate//2, dilate//2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + if return_feature_maps: + conv_out = [] + for i in range(self.total_idx): + x = self.features[i](x) + if i in self.down_idx: + conv_out.append(x) + conv_out.append(x) + return conv_out + + else: + return [self.features(x)] + + +# last conv, deep supervision +class C1DeepSup(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1DeepSup, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# last conv +class C1(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + + return x + + +# pyramid pooling +class PPM(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6)): + super(PPM, self).__init__() + self.use_softmax = use_softmax + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim+len(pool_scales)*512, 512, + kernel_size=3, padding=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', 
align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + return x + + +# pyramid pooling, deep supervision +class PPMDeepsup(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6)): + super(PPMDeepsup, self).__init__() + self.use_softmax = use_softmax + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim+len(pool_scales)*512, 512, + kernel_size=3, padding=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.dropout_deepsup = nn.Dropout2d(0.1) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.dropout_deepsup(_) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# upernet +class UPerNet(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6), + fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256): + super(UPerNet, self).__init__() + self.use_softmax = use_softmax + + # PPM Module + self.ppm_pooling = [] + self.ppm_conv = [] + + for scale in pool_scales: + self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) + self.ppm_conv.append(nn.Sequential( + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm_pooling = nn.ModuleList(self.ppm_pooling) + self.ppm_conv = nn.ModuleList(self.ppm_conv) + self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) + + # FPN Module + self.fpn_in = [] + for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer + self.fpn_in.append(nn.Sequential( + nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), + BatchNorm2d(fpn_dim), + nn.ReLU(inplace=True) + )) + self.fpn_in = nn.ModuleList(self.fpn_in) + + self.fpn_out = [] + for i in range(len(fpn_inplanes) - 1): # skip the top layer + self.fpn_out.append(nn.Sequential( + conv3x3_bn_relu(fpn_dim, fpn_dim, 1), + )) + self.fpn_out = nn.ModuleList(self.fpn_out) + + self.conv_last = nn.Sequential( + conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), + nn.Conv2d(fpn_dim, num_class, kernel_size=1) + ) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale, pool_conv in 
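Editor's note: the channel bookkeeping in these PPM heads is worth spelling out: `conv_last` consumes the backbone map concatenated with one 512-channel branch per pooling scale, hence `fc_dim + len(pool_scales) * 512` input channels. A shape check with a crude channel slice standing in for each branch's 1x1 conv + BN + ReLU (an assumption for brevity, not the module's actual projection):

```python
import torch
import torch.nn.functional as F

fc_dim, pool_scales = 2048, (1, 2, 3, 6)  # e.g. a ResNet-50 encoder
conv5 = torch.randn(1, fc_dim, 16, 16)

ppm_out = [conv5]
for scale in pool_scales:
    pooled = F.adaptive_avg_pool2d(conv5, scale)  # (1, 2048, scale, scale)
    branch = pooled[:, :512]                      # stand-in for the 1x1 conv branch
    ppm_out.append(F.interpolate(branch, (16, 16),
                                 mode="bilinear", align_corners=False))
ppm_out = torch.cat(ppm_out, 1)
print(ppm_out.shape)  # torch.Size([1, 4096, 16, 16]) = fc_dim + 4 * 512
```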
zip(self.ppm_pooling, self.ppm_conv): + ppm_out.append(pool_conv(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False))) + ppm_out = torch.cat(ppm_out, 1) + f = self.ppm_last_conv(ppm_out) + + fpn_feature_list = [f] + for i in reversed(range(len(conv_out) - 1)): + conv_x = conv_out[i] + conv_x = self.fpn_in[i](conv_x) # lateral branch + + f = nn.functional.interpolate( + f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch + f = conv_x + f + + fpn_feature_list.append(self.fpn_out[i](f)) + + fpn_feature_list.reverse() # [P2 - P5] + output_size = fpn_feature_list[0].size()[2:] + fusion_list = [fpn_feature_list[0]] + for i in range(1, len(fpn_feature_list)): + fusion_list.append(nn.functional.interpolate( + fpn_feature_list[i], + output_size, + mode='bilinear', align_corners=False)) + fusion_out = torch.cat(fusion_list, 1) + x = self.conv_last(fusion_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + x = nn.functional.log_softmax(x, dim=1) + + return x diff --git a/third_party/gim/gim/mit_semseg/models/resnet.py b/third_party/gim/gim/mit_semseg/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b5cc981a925feba1db76bdb3e3b99e05472ab508 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/models/resnet.py @@ -0,0 +1,216 @@ +import torch.nn as nn +import math +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d +BatchNorm2d = SynchronizedBatchNorm2d + + +__all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! + + +model_urls = { + 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', + 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', + 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + 
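Editor's note: the decisive step in UPerNet's forward pass above is the top-down FPN loop: the running map `f` is bilinearly upsampled to each lateral feature's size and summed before being pushed through `fpn_out`. Isolated below, with the 1x1 lateral conv (`fpn_in`) omitted and both maps already at `fpn_dim` channels:

```python
import torch
import torch.nn.functional as F

f = torch.randn(1, 256, 8, 8)          # coarsest map, after ppm_last_conv
lateral = torch.randn(1, 256, 16, 16)  # conv_out[i], already mapped to fpn_dim

f = lateral + F.interpolate(f, size=lateral.shape[2:],
                            mode="bilinear", align_corners=False)
print(f.shape)  # (1, 256, 16, 16): becomes the running map for the next level
```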
residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet18'])) + return model + +''' +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet34'])) + return model +''' + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet50']), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. 
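Editor's note: note the stem in this `ResNet`: three 3x3 convolutions (the first with stride 2) ending at 128 channels, rather than torchvision's single 7x7/64 conv. That is why `self.inplanes` starts at 128 and why stock torchvision checkpoints will not load here. A quick shape check of an equivalent stem sketch (plain `nn.BatchNorm2d` in place of the synchronized variant):

```python
import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 3, padding=1, bias=False),
    nn.BatchNorm2d(64), nn.ReLU(inplace=True),
    nn.Conv2d(64, 128, 3, padding=1, bias=False),
    nn.BatchNorm2d(128), nn.ReLU(inplace=True),
    nn.MaxPool2d(3, stride=2, padding=1),
)
x = torch.randn(1, 3, 224, 224)
print(stem(x).shape)  # (1, 128, 56, 56): same 1/4 resolution as the 7x7 stem
```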
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet101']), strict=False) + return model + +# def resnet152(pretrained=False, **kwargs): +# """Constructs a ResNet-152 model. +# +# Args: +# pretrained (bool): If True, returns a model pre-trained on ImageNet +# """ +# model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnet152'])) +# return model diff --git a/third_party/gim/gim/mit_semseg/models/resnext.py b/third_party/gim/gim/mit_semseg/models/resnext.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e260cd24dcdcee552e3ff0acac0c2ac7bd3adc --- /dev/null +++ b/third_party/gim/gim/mit_semseg/models/resnext.py @@ -0,0 +1,163 @@ +import torch.nn as nn +import math +from .utils import load_url +from ..lib.nn import SynchronizedBatchNorm2d +BatchNorm2d = SynchronizedBatchNorm2d + + +__all__ = ['ResNeXt', 'resnext101'] # support resnext 101 + + +model_urls = { + #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', + 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class GroupBottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): + super(GroupBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 2) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNeXt(nn.Module): + + def __init__(self, block, layers, groups=32, num_classes=1000): + self.inplanes = 128 + super(ResNeXt, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) + self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) + self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) + self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(1024 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups + 
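Editor's note: `GroupBottleneck` above differs from the plain `Bottleneck` mainly in its grouped 3x3 conv (cardinality 32 by default) and `expansion = 2`. The grouping is where the parameter savings come from, since each group sees only `planes / groups` input channels:

```python
import torch.nn as nn

planes = 256
dense = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
grouped = nn.Conv2d(planes, planes, 3, padding=1, groups=32, bias=False)

print(dense.weight.numel())    # 589824 = 256 * 256 * 3 * 3
print(grouped.weight.numel())  # 18432  = 256 * (256 / 32) * 3 * 3
```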
m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, groups=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, groups, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=groups)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +''' +def resnext50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext50']), strict=False) + return model +''' + + +def resnext101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on Places + """ + model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnext101']), strict=False) + return model + + +# def resnext152(pretrained=False, **kwargs): +# """Constructs a ResNeXt-152 model. 
+# +# Args: +# pretrained (bool): If True, returns a model pre-trained on Places +# """ +# model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) +# if pretrained: +# model.load_state_dict(load_url(model_urls['resnext152'])) +# return model diff --git a/third_party/gim/gim/mit_semseg/models/utils.py b/third_party/gim/gim/mit_semseg/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7301cbdbcc395adb110184f299fc47a3ce9a8716 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/models/utils.py @@ -0,0 +1,18 @@ +import sys +import os +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve +import torch + + +def load_url(url, model_dir='./pretrained', map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) diff --git a/third_party/gim/gim/mit_semseg/utils.py b/third_party/gim/gim/mit_semseg/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..600e91de91ba0fd93d29d0e03bd70652a65f3e92 --- /dev/null +++ b/third_party/gim/gim/mit_semseg/utils.py @@ -0,0 +1,200 @@ +import sys +import os +import logging +import re +import functools +import fnmatch +import numpy as np + + +def setup_logger(distributed_rank=0, filename="log.txt"): + logger = logging.getLogger("Logger") + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" + ch.setFormatter(logging.Formatter(fmt)) + logger.addHandler(ch) + + return logger + + +def find_recursive(root_dir, ext='.jpg'): + files = [] + for root, dirnames, filenames in os.walk(root_dir): + for filename in fnmatch.filter(filenames, '*' + ext): + files.append(os.path.join(root, filename)) + return files + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.initialized = False + self.val = None + self.avg = None + self.sum = None + self.count = None + + def initialize(self, val, weight): + self.val = val + self.avg = val + self.sum = val * weight + self.count = weight + self.initialized = True + + def update(self, val, weight=1): + if not self.initialized: + self.initialize(val, weight) + else: + self.add(val, weight) + + def add(self, val, weight): + self.val = val + self.sum += val * weight + self.count += weight + self.avg = self.sum / self.count + + def value(self): + return self.val + + def average(self): + return self.avg + + +def unique(ar, return_index=False, return_inverse=False, return_counts=False): + ar = np.asanyarray(ar).flatten() + + optional_indices = return_index or return_inverse + optional_returns = optional_indices or return_counts + + if ar.size == 0: + if not optional_returns: + ret = ar + else: + ret = (ar,) + if return_index: + ret += (np.empty(0, np.bool_),)  # np.bool was removed in NumPy 1.24 + if return_inverse: + ret += (np.empty(0, np.bool_),) + if return_counts: + ret += (np.empty(0, np.intp),) + return ret + if optional_indices: + perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') + aux = ar[perm] + else: + ar.sort() + aux = ar + flag = np.concatenate(([True], 
aux[1:] != aux[:-1])) + + if not optional_returns: + ret = aux[flag] + else: + ret = (aux[flag],) + if return_index: + ret += (perm[flag],) + if return_inverse: + iflag = np.cumsum(flag) - 1 + inv_idx = np.empty(ar.shape, dtype=np.intp) + inv_idx[perm] = iflag + ret += (inv_idx,) + if return_counts: + idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) + ret += (np.diff(idx),) + return ret + + +def colorEncode(labelmap, colors, mode='RGB'): + labelmap = labelmap.astype('int') + labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), + dtype=np.uint8) + for label in unique(labelmap): + if label < 0: + continue + labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ + np.tile(colors[label], + (labelmap.shape[0], labelmap.shape[1], 1)) + + if mode == 'BGR': + return labelmap_rgb[:, :, ::-1] + else: + return labelmap_rgb + + +def accuracy(preds, label): + valid = (label >= 0) + acc_sum = (valid * (preds == label)).sum() + valid_sum = valid.sum() + acc = float(acc_sum) / (valid_sum + 1e-10) + return acc, valid_sum + + +def intersectionAndUnion(imPred, imLab, numClass): + imPred = np.asarray(imPred).copy() + imLab = np.asarray(imLab).copy() + + imPred += 1 + imLab += 1 + # Remove classes from unlabeled pixels in gt image. + # We should not penalize detections in unlabeled portions of the image. + imPred = imPred * (imLab > 0) + + # Compute area intersection: + intersection = imPred * (imPred == imLab) + (area_intersection, _) = np.histogram( + intersection, bins=numClass, range=(1, numClass)) + + # Compute area union: + (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) + (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) + area_union = area_pred + area_lab - area_intersection + + return (area_intersection, area_union) + + +class NotSupportedCliException(Exception): + pass + + +def process_range(xpu, inp): + start, end = map(int, inp) + if start > end: + end, start = start, end + return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) + + +REGEX = [ + (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), + (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), + (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), + functools.partial(process_range, 'gpu')), + (re.compile(r'^(\d+)-(\d+)$'), + functools.partial(process_range, 'gpu')), +] + + +def parse_devices(input_devices): + + """Parse user's devices input str to standard format. + e.g. [gpu0, gpu1, ...] + + """ + ret = [] + for d in input_devices.split(','): + for regex, func in REGEX: + m = regex.match(d.lower().strip()) + if m: + tmp = func(m.groups()) + # prevent duplicate + for x in tmp: + if x not in ret: + ret.append(x) + break + else: + raise NotSupportedCliException( + 'Can not recognize device: "{}"'.format(d)) + return ret
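Editor's note: `parse_devices` normalizes the CLI device string through the `REGEX` table above: bare ids, `gpu`-prefixed ids, and ranges all collapse to deduplicated `gpu<N>` names. For instance:

```python
print(parse_devices("0"))          # ['gpu0']
print(parse_devices("gpu0,gpu1"))  # ['gpu0', 'gpu1']
print(parse_devices("0-2,gpu1"))   # ['gpu0', 'gpu1', 'gpu2'] (duplicate dropped)
```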