diff --git a/src/__pycache__/model_LN_prompt.cpython-310.pyc b/src/__pycache__/model_LN_prompt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4466d99206243f10cf8bf3f47413a419bacf977 Binary files /dev/null and b/src/__pycache__/model_LN_prompt.cpython-310.pyc differ diff --git a/src/__pycache__/options.cpython-310.pyc b/src/__pycache__/options.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d4e4a787d95fbcf4b2c5ab790402962177f9921 Binary files /dev/null and b/src/__pycache__/options.cpython-310.pyc differ diff --git a/src/dinov2/__init__.py b/src/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/src/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/src/dinov2/__pycache__/__init__.cpython-310.pyc b/src/dinov2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cedc1cd8a667a08bffd877281eff30ac38195875 Binary files /dev/null and b/src/dinov2/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/dinov2/configs/__init__.py b/src/dinov2/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68e0830c62ea19649b6cd2361995f6df309d7640 --- /dev/null +++ b/src/dinov2/configs/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import pathlib + +from omegaconf import OmegaConf + + +def load_config(config_name: str): + config_filename = config_name + ".yaml" + return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) + + +dinov2_default_config = load_config("ssl_default_config") + + +def load_and_merge_config(config_name: str): + default_config = OmegaConf.create(dinov2_default_config) + loaded_config = load_config(config_name) + return OmegaConf.merge(default_config, loaded_config) diff --git a/src/dinov2/configs/eval/vitb14_pretrain.yaml b/src/dinov2/configs/eval/vitb14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..117d0f027ca26cd8ce6c010bb78d5a8fac42c70e --- /dev/null +++ b/src/dinov2/configs/eval/vitb14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_base + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/eval/vitb14_reg4_pretrain.yaml b/src/dinov2/configs/eval/vitb14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d53edc04a0761b4b35c147d63e04d55c90092c8f --- /dev/null +++ b/src/dinov2/configs/eval/vitb14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_base + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/eval/vitg14_pretrain.yaml b/src/dinov2/configs/eval/vitg14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a96dd5b117b4d59ee210b65037821f1b3e3f16e3 --- /dev/null +++ b/src/dinov2/configs/eval/vitg14_pretrain.yaml @@ -0,0 +1,7 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/eval/vitg14_reg4_pretrain.yaml b/src/dinov2/configs/eval/vitg14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15948f8589ea0a6e04717453eb88c18388e7f1b2 --- /dev/null +++ b/src/dinov2/configs/eval/vitg14_reg4_pretrain.yaml @@ -0,0 +1,10 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/eval/vitl14_pretrain.yaml b/src/dinov2/configs/eval/vitl14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a984548bd034f762d455419d7193917fa462dd8 --- /dev/null +++ b/src/dinov2/configs/eval/vitl14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_large + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/eval/vitl14_reg4_pretrain.yaml b/src/dinov2/configs/eval/vitl14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e2bc4e7b24b1a64d0369a24927996d0f184e283 --- /dev/null +++ b/src/dinov2/configs/eval/vitl14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_large + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/eval/vits14_pretrain.yaml b/src/dinov2/configs/eval/vits14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afbdb4ba14f1c97130a25b579360f4d817cda495 --- /dev/null +++ b/src/dinov2/configs/eval/vits14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_small + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/eval/vits14_reg4_pretrain.yaml b/src/dinov2/configs/eval/vits14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d25fd638389bfba9220792302dc9dbf5d9a2406a --- /dev/null +++ b/src/dinov2/configs/eval/vits14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_small + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/ssl_default_config.yaml b/src/dinov2/configs/ssl_default_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ccaae1c3174b21bcaf6e803dc861492261e5abe1 --- /dev/null +++ b/src/dinov2/configs/ssl_default_config.yaml @@ -0,0 +1,118 @@ +MODEL: + WEIGHTS: '' +compute_precision: + grad_scaler: true + teacher: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + student: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 +dino: + loss_weight: 1.0 + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 + koleo_loss_weight: 0.1 +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + separate_head: false + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 +train: + batch_size_per_gpu: 64 + dataset_path: ImageNet:split=TRAIN + output_dir: . + saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1250 + cache_dataset: true + centering: "centering" # or "sinkhorn_knopp" +student: + arch: vit_large + patch_size: 16 + drop_path_rate: 0.3 + layerscale: 1.0e-05 + drop_path_uniform: true + pretrained_weights: '' + ffn_layer: "mlp" + block_chunks: 0 + qkv_bias: true + proj_bias: true + ffn_bias: true + num_register_tokens: 0 + interpolate_antialias: false + interpolate_offset: 0.1 +teacher: + momentum_teacher: 0.992 + final_momentum_teacher: 1 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 30 +optim: + epochs: 100 + weight_decay: 0.04 + weight_decay_end: 0.4 + base_lr: 0.004 # learning rate for a batch size of 1024 + lr: 0. # will be set after applying scaling rule + warmup_epochs: 10 + min_lr: 1.0e-06 + clip_grad: 3.0 + freeze_last_layer_epochs: 1 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + layerwise_decay: 0.9 + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 224 + local_crops_size: 96 +evaluation: + eval_period_iterations: 12500 diff --git a/src/dinov2/configs/train/vitg14.yaml b/src/dinov2/configs/train/vitg14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d05cf0d59e07ac6e4a2b0f9bdcb6131d7c508962 --- /dev/null +++ b/src/dinov2/configs/train/vitg14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 12 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_giant2 + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/train/vitl14.yaml b/src/dinov2/configs/train/vitl14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9b491dcc6a522c71328fc2933dd0501123c8f6b --- /dev/null +++ b/src/dinov2/configs/train/vitl14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 32 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_large + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/src/dinov2/configs/train/vitl16_short.yaml b/src/dinov2/configs/train/vitl16_short.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e7e72864c92175a1354142ac1d64da8070d1e5e --- /dev/null +++ b/src/dinov2/configs/train/vitl16_short.yaml @@ -0,0 +1,6 @@ +# this corresponds to the default config +train: + dataset_path: ImageNet:split=TRAIN + batch_size_per_gpu: 64 +student: + block_chunks: 4 diff --git a/src/dinov2/data/__init__.py b/src/dinov2/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2ded47ea63a7b184ff74a040e2c2c514cda273ef --- /dev/null +++ b/src/dinov2/data/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .adapters import DatasetWithEnumeratedTargets +from .loaders import make_data_loader, make_dataset, SamplerType +from .collate import collate_data_and_cast +from .masking import MaskingGenerator +from .augmentations import DataAugmentationDINO diff --git a/src/dinov2/data/adapters.py b/src/dinov2/data/adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..2097bad046fb1052267d5f2bb99c798045f00c92 --- /dev/null +++ b/src/dinov2/data/adapters.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torch.utils.data import Dataset + + +class DatasetWithEnumeratedTargets(Dataset): + def __init__(self, dataset): + self._dataset = dataset + + def get_image_data(self, index: int) -> bytes: + return self._dataset.get_image_data(index) + + def get_target(self, index: int) -> Tuple[Any, int]: + target = self._dataset.get_target(index) + return (index, target) + + def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: + image, target = self._dataset[index] + target = index if target is None else target + return image, (index, target) + + def __len__(self) -> int: + return len(self._dataset) diff --git a/src/dinov2/data/augmentations.py b/src/dinov2/data/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..05b1eaa942c14f75b88d9e14732e141e8909b0a1 --- /dev/null +++ b/src/dinov2/data/augmentations.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from torchvision import transforms + +from .transforms import ( + GaussianBlur, + make_normalize_transform, +) + + +logger = logging.getLogger("dinov2") + + +class DataAugmentationDINO(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + self.local_crops_number = local_crops_number + self.global_crops_size = global_crops_size + self.local_crops_size = local_crops_size + + logger.info("###################################") + logger.info("Using data augmentation parameters:") + logger.info(f"global_crops_scale: {global_crops_scale}") + logger.info(f"local_crops_scale: {local_crops_scale}") + logger.info(f"local_crops_number: {local_crops_number}") + logger.info(f"global_crops_size: {global_crops_size}") + logger.info(f"local_crops_size: {local_crops_size}") + logger.info("###################################") + + # random resized crop and flip + self.geometric_augmentation_global = transforms.Compose( + [ + transforms.RandomResizedCrop( + global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + self.geometric_augmentation_local = transforms.Compose( + [ + transforms.RandomResizedCrop( + local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + # color distorsions / blurring + color_jittering = transforms.Compose( + [ + transforms.RandomApply( + [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], + p=0.8, + ), + transforms.RandomGrayscale(p=0.2), + ] + ) + + global_transfo1_extra = GaussianBlur(p=1.0) + + global_transfo2_extra = transforms.Compose( + [ + GaussianBlur(p=0.1), + transforms.RandomSolarize(threshold=128, p=0.2), + ] + ) + + local_transfo_extra = GaussianBlur(p=0.5) + + # normalization + self.normalize = transforms.Compose( + [ + transforms.ToTensor(), + make_normalize_transform(), + ] + ) + + self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) + self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) + self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) + + def __call__(self, image): + output = {} + + # global crops: + im1_base = self.geometric_augmentation_global(image) + global_crop_1 = self.global_transfo1(im1_base) + + im2_base = self.geometric_augmentation_global(image) + global_crop_2 = self.global_transfo2(im2_base) + + output["global_crops"] = [global_crop_1, global_crop_2] + + # global crops for teacher: + output["global_crops_teacher"] = [global_crop_1, global_crop_2] + + # local crops: + local_crops = [ + self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) + ] + output["local_crops"] = local_crops + output["offsets"] = () + + return output diff --git a/src/dinov2/data/collate.py b/src/dinov2/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..b3e32f357a76e6f32162cee14cb6ae1665a4827a --- /dev/null +++ b/src/dinov2/data/collate.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import random + + +def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None): + # dtype = torch.half # TODO: Remove + + n_global_crops = len(samples_list[0][0]["global_crops"]) + n_local_crops = len(samples_list[0][0]["local_crops"]) + + collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]) + + collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) + + B = len(collated_global_crops) + N = n_tokens + n_samples_masked = int(B * mask_probability) + probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) + upperbound = 0 + masks_list = [] + for i in range(0, n_samples_masked): + prob_min = probs[i] + prob_max = probs[i + 1] + masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max))))) + upperbound += int(N * prob_max) + for i in range(n_samples_masked, B): + masks_list.append(torch.BoolTensor(mask_generator(0))) + + random.shuffle(masks_list) + + collated_masks = torch.stack(masks_list).flatten(1) + mask_indices_list = collated_masks.flatten().nonzero().flatten() + + masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks] + + return { + "collated_global_crops": collated_global_crops.to(dtype), + "collated_local_crops": collated_local_crops.to(dtype), + "collated_masks": collated_masks, + "mask_indices_list": mask_indices_list, + "masks_weight": masks_weight, + "upperbound": upperbound, + "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long), + } diff --git a/src/dinov2/data/datasets/__init__.py b/src/dinov2/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5550fdc5ce16269bc0c28795a389f0182e8bc6c8 --- /dev/null +++ b/src/dinov2/data/datasets/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .image_net import ImageNet +from .image_net_22k import ImageNet22k diff --git a/src/dinov2/data/datasets/decoders.py b/src/dinov2/data/datasets/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..3769f7750d94f7e0f7bce281ef3ff186970fc9cd --- /dev/null +++ b/src/dinov2/data/datasets/decoders.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from io import BytesIO +from typing import Any + +from PIL import Image + + +class Decoder: + def decode(self) -> Any: + raise NotImplementedError + + +class ImageDataDecoder(Decoder): + def __init__(self, image_data: bytes) -> None: + self._image_data = image_data + + def decode(self) -> Image: + f = BytesIO(self._image_data) + return Image.open(f).convert(mode="RGB") + + +class TargetDecoder(Decoder): + def __init__(self, target: Any): + self._target = target + + def decode(self) -> Any: + return self._target diff --git a/src/dinov2/data/datasets/extended.py b/src/dinov2/data/datasets/extended.py new file mode 100644 index 0000000000000000000000000000000000000000..f60b619a3c797823cccfc89e262cdb230f9188f0 --- /dev/null +++ b/src/dinov2/data/datasets/extended.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torchvision.datasets import VisionDataset + +from .decoders import TargetDecoder, ImageDataDecoder + + +class ExtendedVisionDataset(VisionDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # type: ignore + + def get_image_data(self, index: int) -> bytes: + raise NotImplementedError + + def get_target(self, index: int) -> Any: + raise NotImplementedError + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + try: + image_data = self.get_image_data(index) + image = ImageDataDecoder(image_data).decode() + except Exception as e: + raise RuntimeError(f"can not read image for sample {index}") from e + target = self.get_target(index) + target = TargetDecoder(target).decode() + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + raise NotImplementedError diff --git a/src/dinov2/data/datasets/image_net.py b/src/dinov2/data/datasets/image_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8d08446147986c58360163e468896e994197c657 --- /dev/null +++ b/src/dinov2/data/datasets/image_net.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import csv +from enum import Enum +import logging +import os +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np + +from .extended import ExtendedVisionDataset + + +logger = logging.getLogger("dinov2") +_Target = int + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + TEST = "test" # NOTE: torchvision does not support the test split + + @property + def length(self) -> int: + split_lengths = { + _Split.TRAIN: 1_281_167, + _Split.VAL: 50_000, + _Split.TEST: 100_000, + } + return split_lengths[self] + + def get_dirname(self, class_id: Optional[str] = None) -> str: + return self.value if class_id is None else os.path.join(self.value, class_id) + + def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str: + dirname = self.get_dirname(class_id) + if self == _Split.TRAIN: + basename = f"{class_id}_{actual_index}" + else: # self in (_Split.VAL, _Split.TEST): + basename = f"ILSVRC2012_{self.value}_{actual_index:08d}" + return os.path.join(dirname, basename + ".JPEG") + + def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]: + assert self != _Split.TEST + dirname, filename = os.path.split(image_relpath) + class_id = os.path.split(dirname)[-1] + basename, _ = os.path.splitext(filename) + actual_index = int(basename.split("_")[-1]) + return class_id, actual_index + + +class ImageNet(ExtendedVisionDataset): + Target = Union[_Target] + Split = Union[_Split] + + def __init__( + self, + *, + split: "ImageNet.Split", + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + self._split = split + + self._entries = None + self._class_ids = None + self._class_names = None + + @property + def split(self) -> "ImageNet.Split": + return self._split + + def _get_extra_full_path(self, extra_path: str) -> str: + return os.path.join(self._extra_root, extra_path) + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_full_path = self._get_extra_full_path(extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_full_path = self._get_extra_full_path(extra_path) + os.makedirs(self._extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _entries_path(self) -> str: + return f"entries-{self._split.value.upper()}.npy" + + @property + def _class_ids_path(self) -> str: + return f"class-ids-{self._split.value.upper()}.npy" + + @property + def _class_names_path(self) -> str: + return f"class-names-{self._split.value.upper()}.npy" + + def _get_entries(self) -> np.ndarray: + if self._entries is None: + self._entries = self._load_extra(self._entries_path) + assert self._entries is not None + return self._entries + + def _get_class_ids(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class IDs are not available in TEST split" + if self._class_ids is None: + self._class_ids = self._load_extra(self._class_ids_path) + assert self._class_ids is not None + return self._class_ids + + def _get_class_names(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class names are not available in TEST split" + if self._class_names is None: + self._class_names = self._load_extra(self._class_names_path) + assert self._class_names is not None + return self._class_names + + def find_class_id(self, class_index: int) -> str: + class_ids = self._get_class_ids() + return str(class_ids[class_index]) + + def find_class_name(self, class_index: int) -> str: + class_names = self._get_class_names() + return str(class_names[class_index]) + + def get_image_data(self, index: int) -> bytes: + entries = self._get_entries() + actual_index = entries[index]["actual_index"] + + class_id = self.get_class_id(index) + + image_relpath = self.split.get_image_relpath(actual_index, class_id) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Optional[Target]: + entries = self._get_entries() + class_index = entries[index]["class_index"] + return None if self.split == _Split.TEST else int(class_index) + + def get_targets(self) -> Optional[np.ndarray]: + entries = self._get_entries() + return None if self.split == _Split.TEST else entries["class_index"] + + def get_class_id(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_id = entries[index]["class_id"] + return None if self.split == _Split.TEST else str(class_id) + + def get_class_name(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_name = entries[index]["class_name"] + return None if self.split == _Split.TEST else str(class_name) + + def __len__(self) -> int: + entries = self._get_entries() + assert len(entries) == self.split.length + return len(entries) + + def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]: + labels_full_path = os.path.join(self.root, labels_path) + labels = [] + + try: + with open(labels_full_path, "r") as f: + reader = csv.reader(f) + for row in reader: + class_id, class_name = row + labels.append((class_id, class_name)) + except OSError as e: + raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e + + return labels + + def _dump_entries(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + dataset = None + sample_count = split.length + max_class_id_length, max_class_name_length = 0, 0 + else: + labels_path = "labels.txt" + logger.info(f'loading labels from "{labels_path}"') + labels = self._load_labels(labels_path) + + # NOTE: Using torchvision ImageFolder for consistency + from torchvision.datasets import ImageFolder + + dataset_root = os.path.join(self.root, split.get_dirname()) + dataset = ImageFolder(dataset_root) + sample_count = len(dataset) + max_class_id_length, max_class_name_length = -1, -1 + for sample in dataset.samples: + _, class_index = sample + class_id, class_name = labels[class_index] + max_class_id_length = max(len(class_id), max_class_id_length) + max_class_name_length = max(len(class_name), max_class_name_length) + + dtype = np.dtype( + [ + ("actual_index", " old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + actual_index = index + 1 + class_index = np.uint32(-1) + class_id, class_name = "", "" + entries_array[index] = (actual_index, class_index, class_id, class_name) + else: + class_names = {class_id: class_name for class_id, class_name in labels} + + assert dataset + old_percent = -1 + for index in range(sample_count): + percent = 100 * (index + 1) // sample_count + if percent > old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + image_full_path, class_index = dataset.samples[index] + image_relpath = os.path.relpath(image_full_path, self.root) + class_id, actual_index = split.parse_image_relpath(image_relpath) + class_name = class_names[class_id] + entries_array[index] = (actual_index, class_index, class_id, class_name) + + logger.info(f'saving entries to "{self._entries_path}"') + self._save_extra(entries_array, self._entries_path) + + def _dump_class_ids_and_names(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + return + + entries_array = self._load_extra(self._entries_path) + + max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1 + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + max_class_name_length = max(len(str(class_name)), max_class_name_length) + + class_count = max_class_index + 1 + class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}") + class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}") + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + class_ids_array[class_index] = class_id + class_names_array[class_index] = class_name + + logger.info(f'saving class IDs to "{self._class_ids_path}"') + self._save_extra(class_ids_array, self._class_ids_path) + + logger.info(f'saving class names to "{self._class_names_path}"') + self._save_extra(class_names_array, self._class_names_path) + + def dump_extra(self) -> None: + self._dump_entries() + self._dump_class_ids_and_names() diff --git a/src/dinov2/data/datasets/image_net_22k.py b/src/dinov2/data/datasets/image_net_22k.py new file mode 100644 index 0000000000000000000000000000000000000000..52b36a2c664a7b72e30173b03b4e2aef1cd2fcd9 --- /dev/null +++ b/src/dinov2/data/datasets/image_net_22k.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache +from gzip import GzipFile +from io import BytesIO +from mmap import ACCESS_READ, mmap +import os +from typing import Any, Callable, List, Optional, Set, Tuple +import warnings + +import numpy as np + +from .extended import ExtendedVisionDataset + + +_Labels = int + +_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors + + +@dataclass +class _ClassEntry: + block_offset: int + maybe_filename: Optional[str] = None + + +@dataclass +class _Entry: + class_index: int # noqa: E701 + start_offset: int + end_offset: int + filename: str + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + + @property + def length(self) -> int: + return { + _Split.TRAIN: 11_797_647, + _Split.VAL: 561_050, + }[self] + + def entries_path(self): + return f"imagenet21kp_{self.value}.txt" + + +def _get_tarball_path(class_id: str) -> str: + return f"{class_id}.tar" + + +def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int): + @lru_cache(maxsize=mmap_cache_size) + def _mmap_tarball(class_id: str) -> mmap: + tarball_path = _get_tarball_path(class_id) + tarball_full_path = os.path.join(tarballs_root, tarball_path) + with open(tarball_full_path) as f: + return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) + + return _mmap_tarball + + +class ImageNet22k(ExtendedVisionDataset): + _GZIPPED_INDICES: Set[int] = { + 841_545, + 1_304_131, + 2_437_921, + 2_672_079, + 2_795_676, + 2_969_786, + 6_902_965, + 6_903_550, + 6_903_628, + 7_432_557, + 7_432_589, + 7_813_809, + 8_329_633, + 10_296_990, + 10_417_652, + 10_492_265, + 10_598_078, + 10_782_398, + 10_902_612, + 11_203_736, + 11_342_890, + 11_397_596, + 11_589_762, + 11_705_103, + 12_936_875, + 13_289_782, + } + Labels = _Labels + + def __init__( + self, + *, + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + + entries_path = self._get_entries_path(root) + self._entries = self._load_extra(entries_path) + + class_ids_path = self._get_class_ids_path(root) + self._class_ids = self._load_extra(class_ids_path) + + self._gzipped_indices = ImageNet22k._GZIPPED_INDICES + self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size) + + def _get_entries_path(self, root: Optional[str] = None) -> str: + return "entries.npy" + + def _get_class_ids_path(self, root: Optional[str] = None) -> str: + return "class-ids.npy" + + def _find_class_ids(self, path: str) -> List[str]: + class_ids = [] + + with os.scandir(path) as entries: + for entry in entries: + root, ext = os.path.splitext(entry.name) + if ext != ".tar": + continue + class_ids.append(root) + + return sorted(class_ids) + + def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]: + root = self.get_root(root) + entries: List[_Entry] = [] + class_ids = self._find_class_ids(root) + + for class_index, class_id in enumerate(class_ids): + path = os.path.join(root, "blocks", f"{class_id}.log") + class_entries = [] + + try: + with open(path) as f: + for line in f: + line = line.rstrip() + block, filename = line.split(":") + block_offset = int(block[6:]) + filename = filename[1:] + + maybe_filename = None + if filename != "** Block of NULs **": + maybe_filename = filename + _, ext = os.path.splitext(filename) + # assert ext == ".JPEG" + + class_entry = _ClassEntry(block_offset, maybe_filename) + class_entries.append(class_entry) + except OSError as e: + raise RuntimeError(f'can not read blocks file "{path}"') from e + + assert class_entries[-1].maybe_filename is None + + for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]): + assert class_entry1.block_offset <= class_entry2.block_offset + start_offset = 512 * class_entry1.block_offset + end_offset = 512 * class_entry2.block_offset + assert class_entry1.maybe_filename is not None + filename = class_entry1.maybe_filename + entry = _Entry(class_index, start_offset, end_offset, filename) + # Skip invalid image files (PIL throws UnidentifiedImageError) + if filename == "n06470073_47249.JPEG": + continue + entries.append(entry) + + return entries, class_ids + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + os.makedirs(extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _tarballs_root(self) -> str: + return self.root + + def find_class_id(self, class_index: int) -> str: + return str(self._class_ids[class_index]) + + def get_image_data(self, index: int) -> bytes: + entry = self._entries[index] + class_id = entry["class_id"] + class_mmap = self._mmap_tarball(class_id) + + start_offset, end_offset = entry["start_offset"], entry["end_offset"] + try: + mapped_data = class_mmap[start_offset:end_offset] + data = mapped_data[512:] # Skip entry header block + + if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B): + assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}" + with GzipFile(fileobj=BytesIO(data)) as g: + data = g.read() + except Exception as e: + raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e + + return data + + def get_target(self, index: int) -> Any: + return int(self._entries[index]["class_index"]) + + def get_targets(self) -> np.ndarray: + return self._entries["class_index"] + + def get_class_id(self, index: int) -> str: + return str(self._entries[index]["class_id"]) + + def get_class_ids(self) -> np.ndarray: + return self._entries["class_id"] + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return super().__getitem__(index) + + def __len__(self) -> int: + return len(self._entries) + + def _dump_entries(self, *args, **kwargs) -> None: + entries, class_ids = self._load_entries_class_ids(*args, **kwargs) + + max_class_id_length, max_filename_length, max_class_index = -1, -1, -1 + for entry in entries: + class_id = class_ids[entry.class_index] + max_class_index = max(entry.class_index, max_class_index) + max_class_id_length = max(len(class_id), max_class_id_length) + max_filename_length = max(len(entry.filename), max_filename_length) + + dtype = np.dtype( + [ + ("class_index", " None: + entries_path = self._get_entries_path(*args, **kwargs) + entries_array = self._load_extra(entries_path) + + max_class_id_length, max_class_index = -1, -1 + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + + class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}") + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + class_ids_array[class_index] = class_id + class_ids_path = self._get_class_ids_path(*args, **kwargs) + self._save_extra(class_ids_array, class_ids_path) + + def _dump_extra(self, *args, **kwargs) -> None: + self._dump_entries(*args, *kwargs) + self._dump_class_ids(*args, *kwargs) + + def dump_extra(self, root: Optional[str] = None) -> None: + return self._dump_extra(root) diff --git a/src/dinov2/data/loaders.py b/src/dinov2/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a2f0210efa0fa96be764665b5d6792191b1e72 --- /dev/null +++ b/src/dinov2/data/loaders.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +from enum import Enum +from typing import Any, Callable, List, Optional, TypeVar + +import torch +from torch.utils.data import Sampler + +from .datasets import ImageNet, ImageNet22k +from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler + + +logger = logging.getLogger("dinov2") + + +class SamplerType(Enum): + DISTRIBUTED = 0 + EPOCH = 1 + INFINITE = 2 + SHARDED_INFINITE = 3 + SHARDED_INFINITE_NEW = 4 + + +def _make_bool_str(b: bool) -> str: + return "yes" if b else "no" + + +def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): + def transform(sample): + image, target = sample + if image_transform is not None: + image = image_transform(image) + if target_transform is not None: + target = target_transform(target) + return image, target + + return transform + + +def _parse_dataset_str(dataset_str: str): + tokens = dataset_str.split(":") + + name = tokens[0] + kwargs = {} + + for token in tokens[1:]: + key, value = token.split("=") + assert key in ("root", "extra", "split") + kwargs[key] = value + + if name == "ImageNet": + class_ = ImageNet + if "split" in kwargs: + kwargs["split"] = ImageNet.Split[kwargs["split"]] + elif name == "ImageNet22k": + class_ = ImageNet22k + else: + raise ValueError(f'Unsupported dataset "{name}"') + + return class_, kwargs + + +def make_dataset( + *, + dataset_str: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, +): + """ + Creates a dataset with the specified parameters. + + Args: + dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN). + transform: A transform to apply to images. + target_transform: A transform to apply to targets. + + Returns: + The created dataset. + """ + logger.info(f'using dataset: "{dataset_str}"') + + class_, kwargs = _parse_dataset_str(dataset_str) + dataset = class_(transform=transform, target_transform=target_transform, **kwargs) + + logger.info(f"# of dataset samples: {len(dataset):,d}") + + # Aggregated datasets do not expose (yet) these attributes, so add them. + if not hasattr(dataset, "transform"): + setattr(dataset, "transform", transform) + if not hasattr(dataset, "target_transform"): + setattr(dataset, "target_transform", target_transform) + + return dataset + + +def _make_sampler( + *, + dataset, + type: Optional[SamplerType] = None, + shuffle: bool = False, + seed: int = 0, + size: int = -1, + advance: int = 0, +) -> Optional[Sampler]: + sample_count = len(dataset) + + if type == SamplerType.INFINITE: + logger.info("sampler: infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + return InfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + ) + elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW): + logger.info("sampler: sharded infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + # TODO: Remove support for old shuffling + use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW + return ShardedInfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice, + ) + elif type == SamplerType.EPOCH: + logger.info("sampler: epoch") + if advance > 0: + raise NotImplementedError("sampler advance > 0 is not supported") + size = size if size > 0 else sample_count + logger.info(f"# of samples / epoch: {size:,d}") + return EpochSampler( + size=size, + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + ) + elif type == SamplerType.DISTRIBUTED: + logger.info("sampler: distributed") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + if advance > 0: + raise ValueError("sampler advance > 0 is invalid") + return torch.utils.data.DistributedSampler( + dataset=dataset, + shuffle=shuffle, + seed=seed, + drop_last=False, + ) + + logger.info("sampler: none") + return None + + +T = TypeVar("T") + + +def make_data_loader( + *, + dataset, + batch_size: int, + num_workers: int, + shuffle: bool = True, + seed: int = 0, + sampler_type: Optional[SamplerType] = SamplerType.INFINITE, + sampler_size: int = -1, + sampler_advance: int = 0, + drop_last: bool = True, + persistent_workers: bool = False, + collate_fn: Optional[Callable[[List[T]], Any]] = None, +): + """ + Creates a data loader with the specified parameters. + + Args: + dataset: A dataset (third party, LaViDa or WebDataset). + batch_size: The size of batches to generate. + num_workers: The number of workers to use. + shuffle: Whether to shuffle samples. + seed: The random seed to use. + sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None. + sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset. + sampler_advance: How many samples to skip (when applicable). + drop_last: Whether the last non-full batch of data should be dropped. + persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once. + collate_fn: Function that performs batch collation + """ + + sampler = _make_sampler( + dataset=dataset, + type=sampler_type, + shuffle=shuffle, + seed=seed, + size=sampler_size, + advance=sampler_advance, + ) + + logger.info("using PyTorch data loader") + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + drop_last=drop_last, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + ) + + try: + logger.info(f"# of batches: {len(data_loader):,d}") + except TypeError: # data loader has no length + logger.info("infinite data loader") + return data_loader diff --git a/src/dinov2/data/masking.py b/src/dinov2/data/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..ab12aa7bf138b916b16a9a2ed1a628a2759dbec6 --- /dev/null +++ b/src/dinov2/data/masking.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import random +import math +import numpy as np + + +class MaskingGenerator: + def __init__( + self, + input_size, + num_masking_patches=None, + min_num_patches=4, + max_num_patches=None, + min_aspect=0.3, + max_aspect=None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.num_masking_patches = num_masking_patches + + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def __repr__(self): + repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.min_num_patches, + self.max_num_patches, + self.num_masking_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self, num_masking_patches=0): + mask = np.zeros(shape=self.get_shape(), dtype=bool) + mask_count = 0 + while mask_count < num_masking_patches: + max_mask_patches = num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return mask diff --git a/src/dinov2/data/samplers.py b/src/dinov2/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..6562197d94652bb9a75a5fc722fcb2c65ca161be --- /dev/null +++ b/src/dinov2/data/samplers.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +from typing import Any, Optional +import warnings + +import numpy as np +import torch +from torch.utils.data.sampler import Sampler + +import dinov2.distributed as distributed + + +class EpochSampler(Sampler): + def __init__( + self, + *, + size: int, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + ): + self._size = size + self._sample_count = sample_count + self._shuffle = shuffle + self._seed = seed + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._epoch = 0 + + def __iter__(self): + count = (self._size + self._sample_count - 1) // self._sample_count + tiled_indices = np.tile(np.arange(self._sample_count), count) + if self._shuffle: + seed = self._seed * self._epoch if self._seed != 0 else self._epoch + rng = np.random.default_rng(seed) + iterable = rng.choice(tiled_indices, self._size, replace=False) + else: + iterable = tiled_indices[: self._size] + + yield from itertools.islice(iterable, self._start, None, self._step) + + def __len__(self): + return (self._size - self._start + self._step - 1) // self._step + + def set_epoch(self, epoch): + self._epoch = epoch + + +def _get_numpy_dtype(size: int) -> Any: + return np.int32 if size <= 2**31 else np.int64 + + +def _get_torch_dtype(size: int) -> Any: + return torch.int32 if size <= 2**31 else torch.int64 + + +def _generate_randperm_indices(*, size: int, generator: torch.Generator): + """Generate the indices of a random permutation.""" + dtype = _get_torch_dtype(size) + # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 + perm = torch.arange(size, dtype=dtype) + for i in range(size): + j = torch.randint(i, size, size=(1,), generator=generator).item() + + # Always swap even if no-op + value = perm[j].item() + perm[j] = perm[i].item() + perm[i] = value + yield value + + +class InfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + + def __iter__(self): + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator().manual_seed(self._seed) + + while True: + iterable = _generate_randperm_indices(size=self._sample_count, generator=generator) + yield from itertools.islice(iterable, self._start, None, self._step) + + +# The following function is somewhat equivalent to _new_shuffle_tensor_slice below, +# but avoids a full in-place random permutation generation. +def _shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + + dtype = _get_numpy_dtype(stop) + result = np.empty(count, dtype=dtype) + + for i in range(count): + j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 + + result[i] = result[j] + result[j] = tensor[start + i * step].item() + + return result + + +def _new_shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + dtype = torch.int64 # Needed for using randperm result as indices + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + indices = torch.randperm(count, dtype=dtype, generator=generator) + return tensor[start::step][indices].numpy() + + +def _make_seed(seed: int, start: int, iter_count: int) -> int: + # NOTE: Tried a few variants (including iter_count << 32), this one worked best. + return seed + start + (iter_count << 24) + + +class ShardedInfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + use_new_shuffle_tensor_slice: bool = False, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + self._iter_count = 0 + self._shuffle_tensor_slice_fn = ( + _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice + ) + + def __iter__(self): + iter_count = self._advance // self._sample_count + if iter_count > 0: + self._advance -= iter_count * self._sample_count + self._iter_count += iter_count + + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to be keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator() + + # Always shuffle everything first + generator.manual_seed(self._seed) + dtype = _get_torch_dtype(self._sample_count) + perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) + + while True: + # Re-seed on each iteration to allow skipping whole permutations + seed = _make_seed(self._seed, self._start, self._iter_count) + generator.manual_seed(seed) + + iterable = self._shuffle_tensor_slice_fn( + tensor=perm, start=self._start, step=self._step, generator=generator + ) + yield from iterable + self._iter_count += 1 diff --git a/src/dinov2/data/transforms.py b/src/dinov2/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5f252b50c54d58f160528c9f2b00fad47103c7 --- /dev/null +++ b/src/dinov2/data/transforms.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from typing import Sequence + +import torch +from torchvision import transforms + + +class GaussianBlur(transforms.RandomApply): + """ + Apply Gaussian Blur to the PIL image. + """ + + def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): + # NOTE: torchvision is applying 1 - probability to return the original image + keep_p = 1 - p + transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) + super().__init__(transforms=[transform], p=keep_p) + + +class MaybeToTensor(transforms.ToTensor): + """ + Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + if isinstance(pic, torch.Tensor): + return pic + return super().__call__(pic) + + +# Use timm's names +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + + +def make_normalize_transform( + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Normalize: + return transforms.Normalize(mean=mean, std=std) + + +# This roughly matches torchvision's preset for classification training: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 +def make_classification_train_transform( + *, + crop_size: int = 224, + interpolation=transforms.InterpolationMode.BICUBIC, + hflip_prob: float = 0.5, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +): + transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0.0: + transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) + transforms_list.extend( + [ + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + ) + return transforms.Compose(transforms_list) + + +# This matches (roughly) torchvision's preset for classification evaluation: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 +def make_classification_eval_transform( + *, + resize_size: int = 256, + interpolation=transforms.InterpolationMode.BICUBIC, + crop_size: int = 224, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Compose: + transforms_list = [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + return transforms.Compose(transforms_list) diff --git a/src/dinov2/distributed/__init__.py b/src/dinov2/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23226f4536bf5acf4ffac242e9903d92863b246d --- /dev/null +++ b/src/dinov2/distributed/__init__.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +import random +import re +import socket +from typing import Dict, List + +import torch +import torch.distributed as dist + +_LOCAL_RANK = -1 +_LOCAL_WORLD_SIZE = -1 + + +def is_enabled() -> bool: + """ + Returns: + True if distributed training is enabled + """ + return dist.is_available() and dist.is_initialized() + + +def get_global_size() -> int: + """ + Returns: + The number of processes in the process group + """ + return dist.get_world_size() if is_enabled() else 1 + + +def get_global_rank() -> int: + """ + Returns: + The rank of the current process within the global process group. + """ + return dist.get_rank() if is_enabled() else 0 + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not is_enabled(): + return 0 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_RANK + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not is_enabled(): + return 1 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_WORLD_SIZE + + +def is_main_process() -> bool: + """ + Returns: + True if the current process is the main one. + """ + return get_global_rank() == 0 + + +def _restrict_print_to_main_process() -> None: + """ + This function disables printing when not in the main process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_main_process() or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def _get_master_port(seed: int = 0) -> int: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) + + master_port_str = os.environ.get("MASTER_PORT") + if master_port_str is None: + rng = random.Random(seed) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + return int(master_port_str) + + +def _get_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # A "" host address means INADDR_ANY i.e. binding to all interfaces. + # Note this is not compatible with IPv6. + s.bind(("", 0)) + port = s.getsockname()[1] + return port + + +_TORCH_DISTRIBUTED_ENV_VARS = ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", +) + + +def _collect_env_vars() -> Dict[str, str]: + return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ} + + +def _is_slurm_job_process() -> bool: + return "SLURM_JOB_ID" in os.environ + + +def _parse_slurm_node_list(s: str) -> List[str]: + nodes = [] + # Extract "hostname", "hostname[1-2,3,4-5]," substrings + p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") + for m in p.finditer(s): + prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] + for suffix in suffixes.split(","): + span = suffix.split("-") + if len(span) == 1: + nodes.append(prefix + suffix) + else: + width = len(span[0]) + start, end = int(span[0]), int(span[1]) + 1 + nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) + return nodes + + +def _check_env_variable(key: str, new_value: str): + # Only check for difference with preset environment variables + if key in os.environ and os.environ[key] != new_value: + raise RuntimeError(f"Cannot export environment variables as {key} is already set") + + +class _TorchDistributedEnvironment: + def __init__(self): + self.master_addr = "127.0.0.1" + self.master_port = 0 + self.rank = -1 + self.world_size = -1 + self.local_rank = -1 + self.local_world_size = -1 + + if _is_slurm_job_process(): + return self._set_from_slurm_env() + + env_vars = _collect_env_vars() + if not env_vars: + # Environment is not set + pass + elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): + # Environment is fully set + return self._set_from_preset_env() + else: + # Environment is partially set + collected_env_vars = ", ".join(env_vars.keys()) + raise RuntimeError(f"Partially set environment: {collected_env_vars}") + + if torch.cuda.device_count() > 0: + return self._set_from_local() + + raise RuntimeError("Can't initialize PyTorch distributed environment") + + # Slurm job created with sbatch, submitit, etc... + def _set_from_slurm_env(self): + # logger.info("Initialization from Slurm environment") + job_id = int(os.environ["SLURM_JOB_ID"]) + node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) + nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) + assert len(nodes) == node_count + + self.master_addr = nodes[0] + self.master_port = _get_master_port(seed=job_id) + self.rank = int(os.environ["SLURM_PROCID"]) + self.world_size = int(os.environ["SLURM_NTASKS"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["SLURM_LOCALID"]) + self.local_world_size = self.world_size // node_count + assert self.local_rank < self.local_world_size + + # Single node job with preset environment (i.e. torchrun) + def _set_from_preset_env(self): + # logger.info("Initialization from preset environment") + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = os.environ["MASTER_PORT"] + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + assert self.local_rank < self.local_world_size + + # Single node and GPU job (i.e. local script run) + def _set_from_local(self): + # logger.info("Initialization from local") + self.master_addr = "127.0.0.1" + self.master_port = _get_available_port() + self.rank = 0 + self.world_size = 1 + self.local_rank = 0 + self.local_world_size = 1 + + def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": + # See the "Environment variable initialization" section from + # https://pytorch.org/docs/stable/distributed.html for the complete list of + # environment variables required for the env:// initialization method. + env_vars = { + "MASTER_ADDR": self.master_addr, + "MASTER_PORT": str(self.master_port), + "RANK": str(self.rank), + "WORLD_SIZE": str(self.world_size), + "LOCAL_RANK": str(self.local_rank), + "LOCAL_WORLD_SIZE": str(self.local_world_size), + } + if not overwrite: + for k, v in env_vars.items(): + _check_env_variable(k, v) + + os.environ.update(env_vars) + return self + + +def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False): + """Enable distributed mode + + Args: + set_cuda_current_device: If True, call torch.cuda.set_device() to set the + current PyTorch CUDA device to the one matching the local rank. + overwrite: If True, overwrites already set variables. Else fails. + """ + + global _LOCAL_RANK, _LOCAL_WORLD_SIZE + if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: + raise RuntimeError("Distributed mode has already been enabled") + torch_env = _TorchDistributedEnvironment() + torch_env.export(overwrite=overwrite) + + if set_cuda_current_device: + torch.cuda.set_device(torch_env.local_rank) + + if allow_nccl_timeout: + # This allows to use torch distributed timeout in a NCCL backend + key, value = "NCCL_ASYNC_ERROR_HANDLING", "1" + if not overwrite: + _check_env_variable(key, value) + os.environ[key] = value + + dist.init_process_group(backend="nccl") + dist.barrier() + + # Finalize setup + _LOCAL_RANK = torch_env.local_rank + _LOCAL_WORLD_SIZE = torch_env.local_world_size + _restrict_print_to_main_process() diff --git a/src/dinov2/eval/__init__.py b/src/dinov2/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/dinov2/eval/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/dinov2/eval/depth/__init__.py b/src/dinov2/eval/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/dinov2/eval/depth/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/dinov2/eval/depth/models/__init__.py b/src/dinov2/eval/depth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5825181dc2189424b5c58d245b36919cbc5b2e --- /dev/null +++ b/src/dinov2/eval/depth/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss +from .decode_heads import * # noqa: F403 +from .depther import * # noqa: F403 +from .losses import * # noqa: F403 diff --git a/src/dinov2/eval/depth/models/backbones/__init__.py b/src/dinov2/eval/depth/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..520d75bc6e064b9d64487293604ac1bda6e2b6f7 --- /dev/null +++ b/src/dinov2/eval/depth/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vision_transformer import DinoVisionTransformer diff --git a/src/dinov2/eval/depth/models/backbones/vision_transformer.py b/src/dinov2/eval/depth/models/backbones/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..69bda46fd69eb7dabb8f5b60e6fa459fdc21aeab --- /dev/null +++ b/src/dinov2/eval/depth/models/backbones/vision_transformer.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.runner import BaseModule + +from ..builder import BACKBONES + + +@BACKBONES.register_module() +class DinoVisionTransformer(BaseModule): + """Vision Transformer.""" + + def __init__(self, *args, **kwargs): + super().__init__() diff --git a/src/dinov2/eval/depth/models/builder.py b/src/dinov2/eval/depth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c152643435308afcff60b07cd68ea979fe1d90cb --- /dev/null +++ b/src/dinov2/eval/depth/models/builder.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION +from mmcv.utils import Registry + +MODELS = Registry("models", parent=MMCV_MODELS) +ATTENTION = Registry("attention", parent=MMCV_ATTENTION) + + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +DEPTHER = MODELS + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_depther(cfg, train_cfg=None, test_cfg=None): + """Build depther.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning) + assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field " + assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field " + return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/src/dinov2/eval/depth/models/decode_heads/__init__.py b/src/dinov2/eval/depth/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0f0754a5b01d7622c1f26bf3f60daea19da4e8 --- /dev/null +++ b/src/dinov2/eval/depth/models/decode_heads/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dpt_head import DPTHead +from .linear_head import BNHead diff --git a/src/dinov2/eval/depth/models/decode_heads/decode_head.py b/src/dinov2/eval/depth/models/decode_heads/decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c867a3ec687090b280d90bb86aee435320acda --- /dev/null +++ b/src/dinov2/eval/depth/models/decode_heads/decode_head.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from abc import ABCMeta, abstractmethod + +import mmcv +import numpy as np +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16, force_fp32 + +from ...ops import resize +from ..builder import build_loss + + +class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_cfg (dict|None): Config of conv layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + loss_decode (dict): Config of decode loss. + Default: dict(type='SigLoss'). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_cfg (dict|None): Config of norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. + """ + + def __init__( + self, + in_channels, + channels=96, + conv_cfg=None, + act_cfg=dict(type="ReLU"), + loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_cfg=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conv_cfg = conv_cfg + self.act_cfg = act_cfg + if isinstance(loss_decode, dict): + self.loss_decode = build_loss(loss_decode) + elif isinstance(loss_decode, (list, tuple)): + self.loss_decode = nn.ModuleList() + for loss in loss_decode: + self.loss_decode.append(build_loss(loss)) + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_cfg = norm_cfg + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.fp16_enabled = False + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def extra_repr(self): + """Extra repr.""" + s = f"align_corners={self.align_corners}" + return s + + @auto_fp16() + @abstractmethod + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output depth map. + """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + @force_fp32(apply_to=("depth_pred",)) + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = mmcv.imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} diff --git a/src/dinov2/eval/depth/models/decode_heads/dpt_head.py b/src/dinov2/eval/depth/models/decode_heads/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c6c6d9470d78e1d944cc505f97865f026a9458d3 --- /dev/null +++ b/src/dinov2/eval/depth/models/decode_heads/dpt_head.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule, Linear, build_activation_layer +from mmcv.runner import BaseModule + +from ...ops import resize +from ..builder import HEADS +from .decode_head import DepthBaseDecodeHead + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(BaseModule): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__( + self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None + ): + super(ReassembleBlocks, self).__init__(init_cfg) + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_cfg=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU"))) + ) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(BaseModule): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_cfg (dict): dictionary to construct and config activation layer. + norm_cfg (dict): dictionary to construct and config norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None): + super(PreActResidualConvUnit, self).__init__(init_cfg) + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(BaseModule): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_cfg (dict): The activation config for ResidualConvUnit. + norm_cfg (dict): Config dict for normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + init_cfg (dict, optional): Initialization config dict. Default: None. + """ + + def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None): + super(FeatureFusionBlock, self).__init__(init_cfg) + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +@HEADS.register_module() +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git a/src/dinov2/eval/depth/models/decode_heads/linear_head.py b/src/dinov2/eval/depth/models/decode_heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3da1436f6a3f0bcc389d74ed86d44d455d2f7a87 --- /dev/null +++ b/src/dinov2/eval/depth/models/decode_heads/linear_head.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...ops import resize +from ..builder import HEADS +from .decode_head import DepthBaseDecodeHead + + +@HEADS.register_module() +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + + return output diff --git a/src/dinov2/eval/depth/models/depther/__init__.py b/src/dinov2/eval/depth/models/depther/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be99743bf6c773d05f2b74524116e368c0cfcba0 --- /dev/null +++ b/src/dinov2/eval/depth/models/depther/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .base import BaseDepther +from .encoder_decoder import DepthEncoderDecoder diff --git a/src/dinov2/eval/depth/models/depther/base.py b/src/dinov2/eval/depth/models/depther/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e133a825a888167f90d95d67803609d6cac7ff55 --- /dev/null +++ b/src/dinov2/eval/depth/models/depther/base.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule, auto_fp16 + + +class BaseDepther(BaseModule, metaclass=ABCMeta): + """Base class for depther.""" + + def __init__(self, init_cfg=None): + super(BaseDepther, self).__init__(init_cfg) + self.fp16_enabled = False + + @property + def with_neck(self): + """bool: whether the depther has neck""" + return hasattr(self, "neck") and self.neck is not None + + @property + def with_auxiliary_head(self): + """bool: whether the depther has auxiliary head""" + return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None + + @property + def with_decode_head(self): + """bool: whether the depther has decode head""" + return hasattr(self, "decode_head") and self.decode_head is not None + + @abstractmethod + def extract_feat(self, imgs): + """Placeholder for extract features from images.""" + pass + + @abstractmethod + def encode_decode(self, img, img_metas): + """Placeholder for encode images with backbone and decode into a + semantic depth map of the same size as input.""" + pass + + @abstractmethod + def forward_train(self, imgs, img_metas, **kwargs): + """Placeholder for Forward function for training.""" + pass + + @abstractmethod + def simple_test(self, img, img_meta, **kwargs): + """Placeholder for single image test.""" + pass + + @abstractmethod + def aug_test(self, imgs, img_metas, **kwargs): + """Placeholder for augmentation test.""" + pass + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + @auto_fp16(apply_to=("img",)) + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/src/dinov2/eval/depth/models/depther/encoder_decoder.py b/src/dinov2/eval/depth/models/depther/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0ec2dd314fdf8ccf4414d81afb95326b7dc0c9 --- /dev/null +++ b/src/dinov2/eval/depth/models/depther/encoder_decoder.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from ...models import builder +from ...models.builder import DEPTHER +from ...ops import resize +from .base import BaseDepther + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +@DEPTHER.register_module() +class DepthEncoderDecoder(BaseDepther): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone, (neck) and decode_head. + """ + + def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None): + super(DepthEncoderDecoder, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight" + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + self._init_decode_head(decode_head) + + if neck is not None: + self.neck = builder.build_neck(neck) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + depth_pred = self.encode_decode(img, img_meta, rescale, size=size) + + return depth_pred + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred diff --git a/src/dinov2/eval/depth/models/losses/__init__.py b/src/dinov2/eval/depth/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f86242e342776da2e0acc61150d15a8d58ff1e0 --- /dev/null +++ b/src/dinov2/eval/depth/models/losses/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .gradientloss import GradientLoss +from .sigloss import SigLoss diff --git a/src/dinov2/eval/depth/models/losses/gradientloss.py b/src/dinov2/eval/depth/models/losses/gradientloss.py new file mode 100644 index 0000000000000000000000000000000000000000..1599878a6b70cdff4f8467e1e875f0d13ea89eca --- /dev/null +++ b/src/dinov2/eval/depth/models/losses/gradientloss.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...models.builder import LOSSES + + +@LOSSES.register_module() +class GradientLoss(nn.Module): + """GradientLoss. + + Adapted from https://www.cs.cornell.edu/projects/megadepth/ + + Args: + valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. + loss_weight (float): Weight of the loss. Default: 1.0. + max_depth (int): When filtering invalid gt, set a max threshold. Default: None. + """ + + def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"): + super(GradientLoss, self).__init__() + self.valid_mask = valid_mask + self.loss_weight = loss_weight + self.max_depth = max_depth + self.loss_name = loss_name + + self.eps = 0.001 # avoid grad explode + + def gradientloss(self, input, target): + input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)] + target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)] + + gradient_loss = 0 + for input, target in zip(input_downscaled, target_downscaled): + if self.valid_mask: + mask = target > 0 + if self.max_depth is not None: + mask = torch.logical_and(target > 0, target <= self.max_depth) + N = torch.sum(mask) + else: + mask = torch.ones_like(target) + N = input.numel() + input_log = torch.log(input + self.eps) + target_log = torch.log(target + self.eps) + log_d_diff = input_log - target_log + + log_d_diff = torch.mul(log_d_diff, mask) + + v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :]) + v_mask = torch.mul(mask[0:-2, :], mask[2:, :]) + v_gradient = torch.mul(v_gradient, v_mask) + + h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:]) + h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:]) + h_gradient = torch.mul(h_gradient, h_mask) + + gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N + + return gradient_loss + + def forward(self, depth_pred, depth_gt): + """Forward function.""" + + gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt) + return gradient_loss diff --git a/src/dinov2/eval/depth/models/losses/sigloss.py b/src/dinov2/eval/depth/models/losses/sigloss.py new file mode 100644 index 0000000000000000000000000000000000000000..e12fad3e6151e4b975dd055193fdaec0206d4a14 --- /dev/null +++ b/src/dinov2/eval/depth/models/losses/sigloss.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from ...models.builder import LOSSES + + +@LOSSES.register_module() +class SigLoss(nn.Module): + """SigLoss. + + This follows `AdaBins `_. + + Args: + valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True. + loss_weight (float): Weight of the loss. Default: 1.0. + max_depth (int): When filtering invalid gt, set a max threshold. Default: None. + warm_up (bool): A simple warm up stage to help convergence. Default: False. + warm_iter (int): The number of warm up stage. Default: 100. + """ + + def __init__( + self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss" + ): + super(SigLoss, self).__init__() + self.valid_mask = valid_mask + self.loss_weight = loss_weight + self.max_depth = max_depth + self.loss_name = loss_name + + self.eps = 0.001 # avoid grad explode + + # HACK: a hack implementation for warmup sigloss + self.warm_up = warm_up + self.warm_iter = warm_iter + self.warm_up_counter = 0 + + def sigloss(self, input, target): + if self.valid_mask: + valid_mask = target > 0 + if self.max_depth is not None: + valid_mask = torch.logical_and(target > 0, target <= self.max_depth) + input = input[valid_mask] + target = target[valid_mask] + + if self.warm_up: + if self.warm_up_counter < self.warm_iter: + g = torch.log(input + self.eps) - torch.log(target + self.eps) + g = 0.15 * torch.pow(torch.mean(g), 2) + self.warm_up_counter += 1 + return torch.sqrt(g) + + g = torch.log(input + self.eps) - torch.log(target + self.eps) + Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2) + return torch.sqrt(Dg) + + def forward(self, depth_pred, depth_gt): + """Forward function.""" + + loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt) + return loss_depth diff --git a/src/dinov2/eval/depth/ops/__init__.py b/src/dinov2/eval/depth/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78181c29581a281b5f42cf12078636aaeb43b5a5 --- /dev/null +++ b/src/dinov2/eval/depth/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .wrappers import resize diff --git a/src/dinov2/eval/depth/ops/wrappers.py b/src/dinov2/eval/depth/ops/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e --- /dev/null +++ b/src/dinov2/eval/depth/ops/wrappers.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/src/dinov2/eval/knn.py b/src/dinov2/eval/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a4845da1313a6db6b8345bb9a98230fcd24acf --- /dev/null +++ b/src/dinov2/eval/knn.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import torch +from torch.nn.functional import one_hot, softmax + +import dinov2.distributed as distributed +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data.transforms import make_classification_eval_transform +from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--nb_knn", + nargs="+", + type=int, + help="Number of NN to use. 20 is usually working the best.", + ) + parser.add_argument( + "--temperature", + type=float, + help="Temperature used in the voting coefficient", + ) + parser.add_argument( + "--gather-on-cpu", + action="store_true", + help="Whether to gather the train features on cpu, slower" + "but useful to avoid OOM for large datasets (e.g. ImageNet22k).", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch size.", + ) + parser.add_argument( + "--n-per-class-list", + nargs="+", + type=int, + help="Number to take per class", + ) + parser.add_argument( + "--n-tries", + type=int, + help="Number of tries", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=[10, 20, 100, 200], + temperature=0.07, + batch_size=256, + n_per_class_list=[-1], + n_tries=1, + ) + return parser + + +class KnnModule(torch.nn.Module): + """ + Gets knn of test features from all processes on a chunk of the train features + + Each rank gets a chunk of the train features as well as a chunk of the test features. + In `compute_neighbors`, for each rank one after the other, its chunk of test features + is sent to all devices, partial knns are computed with each chunk of train features + then collated back on the original device. + """ + + def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000): + super().__init__() + + self.global_rank = distributed.get_global_rank() + self.global_size = distributed.get_global_size() + + self.device = device + self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device) + self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device) + + self.nb_knn = nb_knn + self.max_k = max(self.nb_knn) + self.T = T + self.num_classes = num_classes + + def _get_knn_sims_and_labels(self, similarity, train_labels): + topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True) + neighbors_labels = torch.gather(train_labels, 1, indices) + return topk_sims, neighbors_labels + + def _similarity_for_rank(self, features_rank, source_rank): + # Send the features from `source_rank` to all ranks + broadcast_shape = torch.tensor(features_rank.shape).to(self.device) + torch.distributed.broadcast(broadcast_shape, source_rank) + + broadcasted = features_rank + if self.global_rank != source_rank: + broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device) + torch.distributed.broadcast(broadcasted, source_rank) + + # Compute the neighbors for `source_rank` among `train_features_rank_T` + similarity_rank = torch.mm(broadcasted, self.train_features_rank_T) + candidate_labels = self.candidates.expand(len(similarity_rank), -1) + return self._get_knn_sims_and_labels(similarity_rank, candidate_labels) + + def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank): + # Gather all neighbors for `target_rank` + topk_sims_rank = retrieved_rank = None + if self.global_rank == target_rank: + topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)] + retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)] + + torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank) + torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank) + + if self.global_rank == target_rank: + # Perform a second top-k on the k * global_size retrieved neighbors + topk_sims_rank = torch.cat(topk_sims_rank, dim=1) + retrieved_rank = torch.cat(retrieved_rank, dim=1) + results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank) + return results + return None + + def compute_neighbors(self, features_rank): + for rank in range(self.global_size): + topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank) + results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank) + if results is not None: + topk_sims_rank, neighbors_labels_rank = results + return topk_sims_rank, neighbors_labels_rank + + def forward(self, features_rank): + """ + Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k` + """ + assert all(k <= self.max_k for k in self.nb_knn) + + topk_sims, neighbors_labels = self.compute_neighbors(features_rank) + batch_size = neighbors_labels.shape[0] + topk_sims_transform = softmax(topk_sims / self.T, 1) + matmul = torch.mul( + one_hot(neighbors_labels, num_classes=self.num_classes), + topk_sims_transform.view(batch_size, -1, 1), + ) + probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn} + return probas_for_k + + +class DictKeysModule(torch.nn.Module): + def __init__(self, keys): + super().__init__() + self.keys = keys + + def forward(self, features_dict, targets): + for k in self.keys: + features_dict = features_dict[k] + return {"preds": features_dict, "target": targets} + + +def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels): + modules = {} + mapping = create_class_indices_mapping(train_labels) + for npc in n_per_class_list: + if npc < 0: # Only one try needed when using the full data + full_module = module( + train_features=train_features, + train_labels=train_labels, + nb_knn=nb_knn, + ) + modules["full"] = ModuleDictWithForward({"1": full_module}) + continue + all_tries = {} + for t in range(n_tries): + final_indices = filter_train(mapping, npc, seed=t) + k_list = list(set(nb_knn + [npc])) + k_list = sorted([el for el in k_list if el <= npc]) + all_tries[str(t)] = module( + train_features=train_features[final_indices], + train_labels=train_labels[final_indices], + nb_knn=k_list, + ) + modules[f"{npc} per class"] = ModuleDictWithForward(all_tries) + + return ModuleDictWithForward(modules) + + +def filter_train(mapping, n_per_class, seed): + torch.manual_seed(seed) + final_indices = [] + for k in mapping.keys(): + index = torch.randperm(len(mapping[k]))[:n_per_class] + final_indices.append(mapping[k][index]) + return torch.cat(final_indices).squeeze() + + +def create_class_indices_mapping(labels): + unique_labels, inverse = torch.unique(labels, return_inverse=True) + mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))} + return mapping + + +class ModuleDictWithForward(torch.nn.ModuleDict): + def forward(self, *args, **kwargs): + return {k: module(*args, **kwargs) for k, module in self._modules.items()} + + +def eval_knn( + model, + train_dataset, + val_dataset, + accuracy_averaging, + nb_knn, + temperature, + batch_size, + num_workers, + gather_on_cpu, + n_per_class_list=[-1], + n_tries=1, +): + model = ModelWithNormalize(model) + + logger.info("Extracting features for train set...") + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu + ) + logger.info(f"Train features created, shape {train_features.shape}.") + + val_dataloader = make_data_loader( + dataset=val_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=True, + ) + num_classes = train_labels.max() + 1 + metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes) + + device = torch.cuda.current_device() + partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes) + knn_module_dict = create_module_dict( + module=partial_module, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + nb_knn=nb_knn, + train_features=train_features, + train_labels=train_labels, + ) + postprocessors, metrics = {}, {} + for n_per_class, knn_module in knn_module_dict.items(): + for t, knn_try in knn_module.items(): + postprocessors = { + **postprocessors, + **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn}, + } + metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}} + model_with_knn = torch.nn.Sequential(model, knn_module_dict) + + # ============ evaluation ... ============ + logger.info("Start the k-NN classification.") + _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device) + + # Averaging the results over the n tries for each value of n_per_class + for n_per_class, knn_module in knn_module_dict.items(): + first_try = list(knn_module.keys())[0] + k_list = knn_module[first_try].nb_knn + for k in k_list: + keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5` + results_dict[(n_per_class, k)] = { + key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()])) + for key in keys + } + for t in knn_module.keys(): + del results_dict[(n_per_class, t, k)] + + return results_dict + + +def eval_knn_with_model( + model, + output_dir, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=(10, 20, 100, 200), + temperature=0.07, + autocast_dtype=torch.float, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=False, + batch_size=256, + num_workers=5, + n_per_class_list=[-1], + n_tries=1, +): + transform = transform or make_classification_eval_transform() + + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=transform, + ) + val_dataset = make_dataset( + dataset_str=val_dataset_str, + transform=transform, + ) + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_knn = eval_knn( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + accuracy_averaging=accuracy_averaging, + nb_knn=nb_knn, + temperature=temperature, + batch_size=batch_size, + num_workers=num_workers, + gather_on_cpu=gather_on_cpu, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + ) + + results_dict = {} + if distributed.is_main_process(): + for knn_ in results_dict_knn.keys(): + top1 = results_dict_knn[knn_]["top-1"].item() * 100.0 + top5 = results_dict_knn[knn_]["top-5"].item() * 100.0 + results_dict[f"{knn_} Top 1"] = top1 + results_dict[f"{knn_} Top 5"] = top5 + logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}") + + metrics_file_path = os.path.join(output_dir, "results_eval_knn.json") + with open(metrics_file_path, "a") as f: + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + + if distributed.is_enabled(): + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_knn_with_model( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + nb_knn=args.nb_knn, + temperature=args.temperature, + autocast_dtype=autocast_dtype, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=args.gather_on_cpu, + batch_size=args.batch_size, + num_workers=5, + n_per_class_list=args.n_per_class_list, + n_tries=args.n_tries, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 k-NN evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/src/dinov2/eval/linear.py b/src/dinov2/eval/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd4c5de5a041be8a188f007257d1e91b6d6921e --- /dev/null +++ b/src/dinov2/eval/linear.py @@ -0,0 +1,625 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer + +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform +import dinov2.distributed as distributed +from dinov2.eval.metrics import MetricType, build_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate +from dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--test-datasets", + dest="test_dataset_strs", + type=str, + nargs="+", + help="Test datasets, none to reuse the validation dataset", + ) + parser.add_argument( + "--epochs", + type=int, + help="Number of training epochs", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch Size (per GPU)", + ) + parser.add_argument( + "--num-workers", + type=int, + help="Number de Workers", + ) + parser.add_argument( + "--epoch-length", + type=int, + help="Length of an epoch in number of iterations", + ) + parser.add_argument( + "--save-checkpoint-frequency", + type=int, + help="Number of epochs between two named checkpoint saves.", + ) + parser.add_argument( + "--eval-period-iterations", + type=int, + help="Number of iterations between two evaluations.", + ) + parser.add_argument( + "--learning-rates", + nargs="+", + type=float, + help="Learning rates to grid search.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not resume from existing checkpoints", + ) + parser.add_argument( + "--val-metric-type", + type=MetricType, + choices=list(MetricType), + help="Validation metric", + ) + parser.add_argument( + "--test-metric-types", + type=MetricType, + choices=list(MetricType), + nargs="+", + help="Evaluation metric", + ) + parser.add_argument( + "--classifier-fpath", + type=str, + help="Path to a file containing pretrained linear classifiers", + ) + parser.add_argument( + "--val-class-mapping-fpath", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.add_argument( + "--test-class-mapping-fpaths", + nargs="+", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + test_dataset_strs=None, + epochs=10, + batch_size=128, + num_workers=8, + epoch_length=1250, + save_checkpoint_frequency=20, + eval_period_iterations=1250, + learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + ) + return parser + + +def has_ddp_wrapper(m: nn.Module) -> bool: + return isinstance(m, DistributedDataParallel) + + +def remove_ddp_wrapper(m: nn.Module) -> nn.Module: + return m.module if has_ddp_wrapper(m) else m + + +def _pad_and_collate(batch): + maxlen = max(len(targets) for image, targets in batch) + padded_batch = [ + (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch + ] + return torch.utils.data.default_collate(padded_batch) + + +def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool): + intermediate_output = x_tokens_list[-use_n_blocks:] + output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) + if use_avgpool: + output = torch.cat( + ( + output, + torch.mean(intermediate_output[-1][0], dim=1), # patch tokens + ), + dim=-1, + ) + output = output.reshape(output.shape[0], -1) + return output.float() + + +class LinearClassifier(nn.Module): + """Linear layer to train on top of frozen features""" + + def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000): + super().__init__() + self.out_dim = out_dim + self.use_n_blocks = use_n_blocks + self.use_avgpool = use_avgpool + self.num_classes = num_classes + self.linear = nn.Linear(out_dim, num_classes) + self.linear.weight.data.normal_(mean=0.0, std=0.01) + self.linear.bias.data.zero_() + + def forward(self, x_tokens_list): + output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool) + return self.linear(output) + + +class AllClassifiers(nn.Module): + def __init__(self, classifiers_dict): + super().__init__() + self.classifiers_dict = nn.ModuleDict() + self.classifiers_dict.update(classifiers_dict) + + def forward(self, inputs): + return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} + + def __len__(self): + return len(self.classifiers_dict) + + +class LinearPostprocessor(nn.Module): + def __init__(self, linear_classifier, class_mapping=None): + super().__init__() + self.linear_classifier = linear_classifier + self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) + + def forward(self, samples, targets): + preds = self.linear_classifier(samples) + return { + "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, + "target": targets, + } + + +def scale_lr(learning_rates, batch_size): + return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 + + +def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000): + linear_classifiers_dict = nn.ModuleDict() + optim_param_groups = [] + for n in n_last_blocks_list: + for avgpool in [False, True]: + for _lr in learning_rates: + lr = scale_lr(_lr, batch_size) + out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1] + linear_classifier = LinearClassifier( + out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes + ) + linear_classifier = linear_classifier.cuda() + linear_classifiers_dict[ + f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_") + ] = linear_classifier + optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr}) + + linear_classifiers = AllClassifiers(linear_classifiers_dict) + if distributed.is_enabled(): + linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) + + return linear_classifiers, optim_param_groups + + +@torch.no_grad() +def evaluate_linear_classifiers( + feature_model, + linear_classifiers, + data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=None, + best_classifier_on_val=None, +): + logger.info("running validation !") + + num_classes = len(class_mapping) if class_mapping is not None else training_num_classes + metric = build_metric(metric_type, num_classes=num_classes) + postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()} + metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} + + _, results_dict_temp = evaluate( + feature_model, + data_loader, + postprocessors, + metrics, + torch.cuda.current_device(), + ) + + logger.info("") + results_dict = {} + max_accuracy = 0 + best_classifier = "" + for i, (classifier_string, metric) in enumerate(results_dict_temp.items()): + logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") + if ( + best_classifier_on_val is None and metric["top-1"].item() > max_accuracy + ) or classifier_string == best_classifier_on_val: + max_accuracy = metric["top-1"].item() + best_classifier = classifier_string + + results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} + + logger.info(f"best classifier: {results_dict['best_classifier']}") + + if distributed.is_main_process(): + with open(metrics_file_path, "a") as f: + f.write(f"iter: {iteration}\n") + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + f.write("\n") + + return results_dict + + +def eval_linear( + *, + feature_model, + linear_classifiers, + train_data_loader, + val_data_loader, + metrics_file_path, + optimizer, + scheduler, + output_dir, + max_iter, + checkpoint_period, # In number of iter, creates a new file every period + running_checkpoint_period, # Period to update main checkpoint file + eval_period, + metric_type, + training_num_classes, + resume=True, + classifier_fpath=None, + val_class_mapping=None, +): + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + + periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter) + iteration = start_iter + logger.info("Starting training from iteration {}".format(start_iter)) + metric_logger = MetricLogger(delimiter=" ") + header = "Training" + + for data, labels in metric_logger.log_every( + train_data_loader, + 10, + header, + max_iter, + start_iter, + ): + data = data.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + features = feature_model(data) + outputs = linear_classifiers(features) + + losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()} + loss = sum(losses.values()) + + # compute the gradients + optimizer.zero_grad() + loss.backward() + + # step + optimizer.step() + scheduler.step() + + # log + if iteration % 10 == 0: + torch.cuda.synchronize() + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + print("lr", optimizer.param_groups[0]["lr"]) + + if iteration - start_iter > 5: + if iteration % running_checkpoint_period == 0: + torch.cuda.synchronize() + if distributed.is_main_process(): + logger.info("Checkpointing running_checkpoint") + periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration) + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: + _ = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + prefixstring=f"ITER: {iteration}", + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + torch.cuda.synchronize() + + iteration = iteration + 1 + + val_results_dict = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + return val_results_dict, feature_model, linear_classifiers, iteration + + +def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type): + test_dataset = make_dataset( + dataset_str=test_dataset_str, + transform=make_classification_eval_transform(), + ) + test_data_loader = make_data_loader( + dataset=test_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=False, + collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None, + ) + return test_data_loader + + +def test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + best_classifier_on_val, + prefixstring="", + test_class_mappings=[None], +): + results_dict = {} + for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types): + logger.info(f"Testing on {test_dataset_str}") + test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type) + dataset_results_dict = evaluate_linear_classifiers( + feature_model, + remove_ddp_wrapper(linear_classifiers), + test_data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=class_mapping, + best_classifier_on_val=best_classifier_on_val, + ) + results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"] + return results_dict + + +def run_eval_linear( + model, + output_dir, + train_dataset_str, + val_dataset_str, + batch_size, + epochs, + epoch_length, + num_workers, + save_checkpoint_frequency, + eval_period_iterations, + learning_rates, + autocast_dtype, + test_dataset_strs=None, + resume=True, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, +): + seed = 0 + + if test_dataset_strs is None: + test_dataset_strs = [val_dataset_str] + if test_metric_types is None: + test_metric_types = [val_metric_type] * len(test_dataset_strs) + else: + assert len(test_metric_types) == len(test_dataset_strs) + assert len(test_dataset_strs) == len(test_class_mapping_fpaths) + + train_transform = make_classification_train_transform() + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=train_transform, + ) + training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int)))) + sampler_type = SamplerType.SHARDED_INFINITE + # sampler_type = SamplerType.INFINITE + + n_last_blocks_list = [1, 4] + n_last_blocks = max(n_last_blocks_list) + autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) + feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) + sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) + + linear_classifiers, optim_param_groups = setup_linear_classifiers( + sample_output, + n_last_blocks_list, + learning_rates, + batch_size, + training_num_classes, + ) + + optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0) + max_iter = epochs * epoch_length + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + train_data_loader = make_data_loader( + dataset=train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + seed=seed, + sampler_type=sampler_type, + sampler_advance=start_iter, + drop_last=True, + persistent_workers=True, + ) + val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type) + + checkpoint_period = save_checkpoint_frequency * epoch_length + + if val_class_mapping_fpath is not None: + logger.info(f"Using class mapping from {val_class_mapping_fpath}") + val_class_mapping = np.load(val_class_mapping_fpath) + else: + val_class_mapping = None + + test_class_mappings = [] + for class_mapping_fpath in test_class_mapping_fpaths: + if class_mapping_fpath is not None and class_mapping_fpath != "None": + logger.info(f"Using class mapping from {class_mapping_fpath}") + class_mapping = np.load(class_mapping_fpath) + else: + class_mapping = None + test_class_mappings.append(class_mapping) + + metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") + val_results_dict, feature_model, linear_classifiers, iteration = eval_linear( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + optimizer=optimizer, + scheduler=scheduler, + output_dir=output_dir, + max_iter=max_iter, + checkpoint_period=checkpoint_period, + running_checkpoint_period=epoch_length, + eval_period=eval_period_iterations, + metric_type=val_metric_type, + training_num_classes=training_num_classes, + resume=resume, + val_class_mapping=val_class_mapping, + classifier_fpath=classifier_fpath, + ) + results_dict = {} + if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str: + results_dict = test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + 0, # num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + val_results_dict["best_classifier"]["name"], + prefixstring="", + test_class_mappings=test_class_mappings, + ) + results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"] + results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"] + logger.info("Test Results Dict " + str(results_dict)) + + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + run_eval_linear( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + test_dataset_strs=args.test_dataset_strs, + batch_size=args.batch_size, + epochs=args.epochs, + epoch_length=args.epoch_length, + num_workers=args.num_workers, + save_checkpoint_frequency=args.save_checkpoint_frequency, + eval_period_iterations=args.eval_period_iterations, + learning_rates=args.learning_rates, + autocast_dtype=autocast_dtype, + resume=not args.no_resume, + classifier_fpath=args.classifier_fpath, + val_metric_type=args.val_metric_type, + test_metric_types=args.test_metric_types, + val_class_mapping_fpath=args.val_class_mapping_fpath, + test_class_mapping_fpaths=args.test_class_mapping_fpaths, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 linear evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/src/dinov2/eval/log_regression.py b/src/dinov2/eval/log_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..5f36ec134e0ce25697428a0b3f21cdc2f0145645 --- /dev/null +++ b/src/dinov2/eval/log_regression.py @@ -0,0 +1,444 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import gc +import logging +import sys +import time +from typing import List, Optional + +from cuml.linear_model import LogisticRegression +import torch +import torch.backends.cudnn as cudnn +import torch.distributed +from torch import nn +from torch.utils.data import TensorDataset +from torchmetrics import MetricTracker + +from dinov2.data import make_dataset +from dinov2.data.transforms import make_classification_eval_transform +from dinov2.distributed import get_global_rank, get_global_size +from dinov2.eval.metrics import MetricType, build_metric +from dinov2.eval.setup import get_args_parser as get_setup_args_parser +from dinov2.eval.setup import setup_and_build_model +from dinov2.eval.utils import evaluate, extract_features +from dinov2.utils.dtype import as_torch_dtype + + +logger = logging.getLogger("dinov2") + +DEFAULT_MAX_ITER = 1_000 +C_POWER_RANGE = torch.linspace(-6, 5, 45) +_CPU_DEVICE = torch.device("cpu") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--finetune-dataset-str", + dest="finetune_dataset_str", + type=str, + help="Fine-tuning dataset", + ) + parser.add_argument( + "--finetune-on-val", + action="store_true", + help="If there is no finetune dataset, whether to choose the " + "hyperparameters on the val set instead of 10%% of the train dataset", + ) + parser.add_argument( + "--metric-type", + type=MetricType, + choices=list(MetricType), + help="Metric type", + ) + parser.add_argument( + "--train-features-device", + type=str, + help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s", + ) + parser.add_argument( + "--train-dtype", + type=str, + help="Data type to convert the train features to (default: %(default)s)", + ) + parser.add_argument( + "--max-train-iters", + type=int, + help="Maximum number of train iterations (default: %(default)s)", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + metric_type=MetricType.MEAN_ACCURACY, + train_features_device="cpu", + train_dtype="float64", + max_train_iters=DEFAULT_MAX_ITER, + finetune_on_val=False, + ) + return parser + + +class LogRegModule(nn.Module): + def __init__( + self, + C, + max_iter=DEFAULT_MAX_ITER, + dtype=torch.float64, + device=_CPU_DEVICE, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.estimator = LogisticRegression( + penalty="l2", + C=C, + max_iter=max_iter, + output_type="numpy", + tol=1e-12, + linesearch_max_iter=50, + ) + + def forward(self, samples, targets): + samples_device = samples.device + samples = samples.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + samples = samples.numpy() + probas = self.estimator.predict_proba(samples) + return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets} + + def fit(self, train_features, train_labels): + train_features = train_features.to(dtype=self.dtype, device=self.device) + train_labels = train_labels.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + # both cuML and sklearn only work with numpy arrays on CPU + train_features = train_features.numpy() + train_labels = train_labels.numpy() + self.estimator.fit(train_features, train_labels) + + +def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device): + postprocessors = {"metrics": logreg_model} + metrics = {"metrics": logreg_metric} + return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device) + + +def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE): + logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device) + logreg_model.fit(train_features, train_labels) + return logreg_model + + +def train_and_evaluate( + *, + C, + max_iter, + train_features, + train_labels, + logreg_metric, + test_data_loader, + train_dtype=torch.float64, + train_features_device, + eval_device, +): + logreg_model = train_for_C( + C=C, + max_iter=max_iter, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + return evaluate_model( + logreg_model=logreg_model, + logreg_metric=logreg_metric, + test_data_loader=test_data_loader, + device=eval_device, + ) + + +def sweep_C_values( + *, + train_features, + train_labels, + test_data_loader, + metric_type, + num_classes, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + if metric_type == MetricType.PER_CLASS_ACCURACY: + # If we want to output per-class accuracy, we select the hyperparameters with mean per class + metric_type = MetricType.MEAN_PER_CLASS_ACCURACY + logreg_metric = build_metric(metric_type, num_classes=num_classes) + metric_tracker = MetricTracker(logreg_metric, maximize=True) + ALL_C = 10**C_POWER_RANGE + logreg_models = {} + + train_features = train_features.to(dtype=train_dtype, device=train_features_device) + train_labels = train_labels.to(device=train_features_device) + + for i in range(get_global_rank(), len(ALL_C), get_global_size()): + C = ALL_C[i].item() + logger.info( + f"Training for C = {C:.5f}, dtype={train_dtype}, " + f"features: {train_features.shape}, {train_features.dtype}, " + f"labels: {train_labels.shape}, {train_labels.dtype}" + ) + logreg_models[C] = train_for_C( + C=C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + + gather_list = [None for _ in range(get_global_size())] + torch.distributed.all_gather_object(gather_list, logreg_models) + + logreg_models_gathered = {} + for logreg_dict in gather_list: + logreg_models_gathered.update(logreg_dict) + + for i in range(len(ALL_C)): + metric_tracker.increment() + C = ALL_C[i].item() + evals = evaluate_model( + logreg_model=logreg_models_gathered[C], + logreg_metric=metric_tracker, + test_data_loader=test_data_loader, + device=torch.cuda.current_device(), + ) + logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}") + + best_stats, which_epoch = metric_tracker.best_metric(return_step=True) + best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()} + if which_epoch["top-1"] == i: + best_C = C + logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}") + + return best_stats, best_C + + +def eval_log_regression( + *, + model, + train_dataset, + val_dataset, + finetune_dataset, + metric_type, + batch_size, + num_workers, + finetune_on_val=False, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + """ + Implements the "standard" process for log regression evaluation: + The value of C is chosen by training on train_dataset and evaluating on + finetune_dataset. Then, the final model is trained on a concatenation of + train_dataset and finetune_dataset, and is evaluated on val_dataset. + If there is no finetune_dataset, the value of C is the one that yields + the best results on a random 10% subset of the train dataset + """ + + start = time.time() + + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_features, val_labels = extract_features( + model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_data_loader = torch.utils.data.DataLoader( + TensorDataset(val_features, val_labels), + batch_size=batch_size, + drop_last=False, + num_workers=0, + persistent_workers=False, + ) + + if finetune_dataset is None and finetune_on_val: + logger.info("Choosing hyperparameters on the val dataset") + finetune_features, finetune_labels = val_features, val_labels + elif finetune_dataset is None and not finetune_on_val: + logger.info("Choosing hyperparameters on 10% of the train dataset") + torch.manual_seed(0) + indices = torch.randperm(len(train_features), device=train_features.device) + finetune_index = indices[: len(train_features) // 10] + train_index = indices[len(train_features) // 10 :] + finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index] + train_features, train_labels = train_features[train_index], train_labels[train_index] + else: + logger.info("Choosing hyperparameters on the finetune dataset") + finetune_features, finetune_labels = extract_features( + model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + # release the model - free GPU memory + del model + gc.collect() + torch.cuda.empty_cache() + finetune_data_loader = torch.utils.data.DataLoader( + TensorDataset(finetune_features, finetune_labels), + batch_size=batch_size, + drop_last=False, + ) + + if len(train_labels.shape) > 1: + num_classes = train_labels.shape[1] + else: + num_classes = train_labels.max() + 1 + + logger.info("Using cuML for logistic regression") + + best_stats, best_C = sweep_C_values( + train_features=train_features, + train_labels=train_labels, + test_data_loader=finetune_data_loader, + metric_type=metric_type, + num_classes=num_classes, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + if not finetune_on_val: + logger.info("Best parameter found, concatenating features") + train_features = torch.cat((train_features, finetune_features)) + train_labels = torch.cat((train_labels, finetune_labels)) + + logger.info("Training final model") + logreg_metric = build_metric(metric_type, num_classes=num_classes) + evals = train_and_evaluate( + C=best_C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + logreg_metric=logreg_metric.clone(), + test_data_loader=val_data_loader, + eval_device=torch.cuda.current_device(), + train_dtype=train_dtype, + train_features_device=train_features_device, + ) + + best_stats = evals[1]["metrics"] + + best_stats["best_C"] = best_C + + logger.info(f"Log regression evaluation done in {int(time.time() - start)}s") + return best_stats + + +def eval_log_regression_with_model( + model, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + autocast_dtype=torch.float, + finetune_on_val=False, + metric_type=MetricType.MEAN_ACCURACY, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + cudnn.benchmark = True + + transform = make_classification_eval_transform(resize_size=224) + target_transform = None + + train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform) + val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform) + if finetune_dataset_str is not None: + finetune_dataset = make_dataset( + dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform + ) + else: + finetune_dataset = None + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_logreg = eval_log_regression( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + finetune_dataset=finetune_dataset, + metric_type=metric_type, + batch_size=256, + num_workers=0, # 5, + finetune_on_val=finetune_on_val, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + results_dict = { + "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0, + "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0, + "best_C": results_dict_logreg["best_C"], + } + logger.info( + "\n".join( + [ + "Training of the supervised logistic regression on frozen features completed.\n" + "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]), + "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]), + "obtained for C = {c:.6f}".format(c=results_dict["best_C"]), + ] + ) + ) + + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_log_regression_with_model( + model=model, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + finetune_dataset_str=args.finetune_dataset_str, + autocast_dtype=autocast_dtype, + finetune_on_val=args.finetune_on_val, + metric_type=args.metric_type, + train_dtype=as_torch_dtype(args.train_dtype), + train_features_device=torch.device(args.train_features_device), + max_train_iters=args.max_train_iters, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 logistic regression evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/src/dinov2/eval/metrics.py b/src/dinov2/eval/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..52be81a859dddde82da93c3657c35352d2bb0a48 --- /dev/null +++ b/src/dinov2/eval/metrics.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import logging +from typing import Any, Dict, Optional + +import torch +from torch import Tensor +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.utilities.data import dim_zero_cat, select_topk + + +logger = logging.getLogger("dinov2") + + +class MetricType(Enum): + MEAN_ACCURACY = "mean_accuracy" + MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" + PER_CLASS_ACCURACY = "per_class_accuracy" + IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" + + @property + def accuracy_averaging(self): + return getattr(AccuracyAveraging, self.name, None) + + def __str__(self): + return self.value + + +class AccuracyAveraging(Enum): + MEAN_ACCURACY = "micro" + MEAN_PER_CLASS_ACCURACY = "macro" + PER_CLASS_ACCURACY = "none" + + def __str__(self): + return self.value + + +def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): + if metric_type.accuracy_averaging is not None: + return build_topk_accuracy_metric( + average_type=metric_type.accuracy_averaging, + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: + return build_topk_imagenet_real_accuracy_metric( + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + + raise ValueError(f"Unknown metric type {metric_type}") + + +def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = { + f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks + } + return MetricCollection(metrics) + + +def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} + return MetricCollection(metrics) + + +class ImageNetReaLAccuracy(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.top_k = top_k + self.add_state("tp", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + # preds [B, D] + # target [B, A] + # preds_oh [B, D] with 0 and 1 + # select top K highest probabilities, use one hot representation + preds_oh = select_topk(preds, self.top_k) + # target_oh [B, D + 1] with 0 and 1 + target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) + target = target.long() + # for undefined targets (-1) use a fake value `num_classes` + target[target == -1] = self.num_classes + # fill targets, use one hot representation + target_oh.scatter_(1, target, 1) + # target_oh [B, D] (remove the fake target at index `num_classes`) + target_oh = target_oh[:, :-1] + # tp [B] with 0 and 1 + tp = (preds_oh * target_oh == 1).sum(dim=1) + # at least one match between prediction and target + tp.clip_(max=1) + # ignore instances where no targets are defined + mask = target_oh.sum(dim=1) > 0 + tp = tp[mask] + self.tp.append(tp) # type: ignore + + def compute(self) -> Tensor: + tp = dim_zero_cat(self.tp) # type: ignore + return tp.float().mean() diff --git a/src/dinov2/eval/segmentation/__init__.py b/src/dinov2/eval/segmentation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/dinov2/eval/segmentation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/dinov2/eval/segmentation/hooks/__init__.py b/src/dinov2/eval/segmentation/hooks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..738cc2d2069521ea0353acd0cb0a03e3ddf1fa51 --- /dev/null +++ b/src/dinov2/eval/segmentation/hooks/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .optimizer import DistOptimizerHook diff --git a/src/dinov2/eval/segmentation/hooks/optimizer.py b/src/dinov2/eval/segmentation/hooks/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f593f26a84475bbf7ebda9607a4d10914b13a443 --- /dev/null +++ b/src/dinov2/eval/segmentation/hooks/optimizer.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +try: + import apex +except ImportError: + print("apex is not installed") + +from mmcv.runner import OptimizerHook, HOOKS + + +@HOOKS.register_module() +class DistOptimizerHook(OptimizerHook): + """Optimizer hook for distributed training.""" + + def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.update_interval = update_interval + self.use_fp16 = use_fp16 + + def before_run(self, runner): + runner.optimizer.zero_grad() + + def after_train_iter(self, runner): + runner.outputs["loss"] /= self.update_interval + if self.use_fp16: + # runner.outputs['loss'].backward() + with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss: + scaled_loss.backward() + else: + runner.outputs["loss"].backward() + if self.every_n_iters(runner, self.update_interval): + if self.grad_clip is not None: + self.clip_grads(runner.model.parameters()) + runner.optimizer.step() + runner.optimizer.zero_grad() diff --git a/src/dinov2/eval/segmentation/models/__init__.py b/src/dinov2/eval/segmentation/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88e4563d4c162d67e7900955a06bd9248d4c9a48 --- /dev/null +++ b/src/dinov2/eval/segmentation/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .decode_heads import * # noqa: F403 diff --git a/src/dinov2/eval/segmentation/models/backbones/__init__.py b/src/dinov2/eval/segmentation/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..520d75bc6e064b9d64487293604ac1bda6e2b6f7 --- /dev/null +++ b/src/dinov2/eval/segmentation/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vision_transformer import DinoVisionTransformer diff --git a/src/dinov2/eval/segmentation/models/backbones/vision_transformer.py b/src/dinov2/eval/segmentation/models/backbones/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e9753ae92a36be52f100e3004cbeeff777d14a --- /dev/null +++ b/src/dinov2/eval/segmentation/models/backbones/vision_transformer.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.runner import BaseModule +from mmseg.models.builder import BACKBONES + + +@BACKBONES.register_module() +class DinoVisionTransformer(BaseModule): + """Vision Transformer.""" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__() diff --git a/src/dinov2/eval/segmentation/models/decode_heads/__init__.py b/src/dinov2/eval/segmentation/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c55317875262dadf8970c2b3882f016b8d4731ac --- /dev/null +++ b/src/dinov2/eval/segmentation/models/decode_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .linear_head import BNHead diff --git a/src/dinov2/eval/segmentation/models/decode_heads/linear_head.py b/src/dinov2/eval/segmentation/models/decode_heads/linear_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f39c68fb136f84d1aa5284da5b69581bb177cc --- /dev/null +++ b/src/dinov2/eval/segmentation/models/decode_heads/linear_head.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from mmseg.models.builder import HEADS +from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.ops import resize + + +@HEADS.register_module() +class BNHead(BaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, resize_factors=None, **kwargs): + super().__init__(**kwargs) + assert self.in_channels == self.channels + self.bn = nn.SyncBatchNorm(self.in_channels) + self.resize_factors = resize_factors + + def _forward_feature(self, inputs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # print("inputs", [i.shape for i in inputs]) + x = self._transform_inputs(inputs) + # print("x", x.shape) + feats = self.bn(x) + # print("feats", feats.shape) + return feats + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == "resize_concat": + # accept lists (for cls token) + input_list = [] + for x in inputs: + if isinstance(x, list): + input_list.extend(x) + else: + input_list.append(x) + inputs = input_list + # an image descriptor can be a local descriptor with resolution 1x1 + for i, x in enumerate(inputs): + if len(x.shape) == 2: + inputs[i] = x[:, :, None, None] + # select indices + inputs = [inputs[i] for i in self.in_index] + # Resizing shenanigans + # print("before", *(x.shape for x in inputs)) + if self.resize_factors is not None: + assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs)) + inputs = [ + resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area") + for x, f in zip(inputs, self.resize_factors) + ] + # print("after", *(x.shape for x in inputs)) + upsampled_inputs = [ + resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) + for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def forward(self, inputs): + """Forward function.""" + output = self._forward_feature(inputs) + output = self.cls_seg(output) + return output diff --git a/src/dinov2/eval/segmentation/utils/__init__.py b/src/dinov2/eval/segmentation/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/dinov2/eval/segmentation/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/dinov2/eval/segmentation/utils/colormaps.py b/src/dinov2/eval/segmentation/utils/colormaps.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ef604b2c75792e95e438abfd51ab03d40de340 --- /dev/null +++ b/src/dinov2/eval/segmentation/utils/colormaps.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +ADE20K_COLORMAP = [ + (0, 0, 0), + (120, 120, 120), + (180, 120, 120), + (6, 230, 230), + (80, 50, 50), + (4, 200, 3), + (120, 120, 80), + (140, 140, 140), + (204, 5, 255), + (230, 230, 230), + (4, 250, 7), + (224, 5, 255), + (235, 255, 7), + (150, 5, 61), + (120, 120, 70), + (8, 255, 51), + (255, 6, 82), + (143, 255, 140), + (204, 255, 4), + (255, 51, 7), + (204, 70, 3), + (0, 102, 200), + (61, 230, 250), + (255, 6, 51), + (11, 102, 255), + (255, 7, 71), + (255, 9, 224), + (9, 7, 230), + (220, 220, 220), + (255, 9, 92), + (112, 9, 255), + (8, 255, 214), + (7, 255, 224), + (255, 184, 6), + (10, 255, 71), + (255, 41, 10), + (7, 255, 255), + (224, 255, 8), + (102, 8, 255), + (255, 61, 6), + (255, 194, 7), + (255, 122, 8), + (0, 255, 20), + (255, 8, 41), + (255, 5, 153), + (6, 51, 255), + (235, 12, 255), + (160, 150, 20), + (0, 163, 255), + (140, 140, 140), + (250, 10, 15), + (20, 255, 0), + (31, 255, 0), + (255, 31, 0), + (255, 224, 0), + (153, 255, 0), + (0, 0, 255), + (255, 71, 0), + (0, 235, 255), + (0, 173, 255), + (31, 0, 255), + (11, 200, 200), + (255, 82, 0), + (0, 255, 245), + (0, 61, 255), + (0, 255, 112), + (0, 255, 133), + (255, 0, 0), + (255, 163, 0), + (255, 102, 0), + (194, 255, 0), + (0, 143, 255), + (51, 255, 0), + (0, 82, 255), + (0, 255, 41), + (0, 255, 173), + (10, 0, 255), + (173, 255, 0), + (0, 255, 153), + (255, 92, 0), + (255, 0, 255), + (255, 0, 245), + (255, 0, 102), + (255, 173, 0), + (255, 0, 20), + (255, 184, 184), + (0, 31, 255), + (0, 255, 61), + (0, 71, 255), + (255, 0, 204), + (0, 255, 194), + (0, 255, 82), + (0, 10, 255), + (0, 112, 255), + (51, 0, 255), + (0, 194, 255), + (0, 122, 255), + (0, 255, 163), + (255, 153, 0), + (0, 255, 10), + (255, 112, 0), + (143, 255, 0), + (82, 0, 255), + (163, 255, 0), + (255, 235, 0), + (8, 184, 170), + (133, 0, 255), + (0, 255, 92), + (184, 0, 255), + (255, 0, 31), + (0, 184, 255), + (0, 214, 255), + (255, 0, 112), + (92, 255, 0), + (0, 224, 255), + (112, 224, 255), + (70, 184, 160), + (163, 0, 255), + (153, 0, 255), + (71, 255, 0), + (255, 0, 163), + (255, 204, 0), + (255, 0, 143), + (0, 255, 235), + (133, 255, 0), + (255, 0, 235), + (245, 0, 255), + (255, 0, 122), + (255, 245, 0), + (10, 190, 212), + (214, 255, 0), + (0, 204, 255), + (20, 0, 255), + (255, 255, 0), + (0, 153, 255), + (0, 41, 255), + (0, 255, 204), + (41, 0, 255), + (41, 255, 0), + (173, 0, 255), + (0, 245, 255), + (71, 0, 255), + (122, 0, 255), + (0, 255, 184), + (0, 92, 255), + (184, 255, 0), + (0, 133, 255), + (255, 214, 0), + (25, 194, 194), + (102, 255, 0), + (92, 0, 255), +] + +ADE20K_CLASS_NAMES = [ + "", + "wall", + "building;edifice", + "sky", + "floor;flooring", + "tree", + "ceiling", + "road;route", + "bed", + "windowpane;window", + "grass", + "cabinet", + "sidewalk;pavement", + "person;individual;someone;somebody;mortal;soul", + "earth;ground", + "door;double;door", + "table", + "mountain;mount", + "plant;flora;plant;life", + "curtain;drape;drapery;mantle;pall", + "chair", + "car;auto;automobile;machine;motorcar", + "water", + "painting;picture", + "sofa;couch;lounge", + "shelf", + "house", + "sea", + "mirror", + "rug;carpet;carpeting", + "field", + "armchair", + "seat", + "fence;fencing", + "desk", + "rock;stone", + "wardrobe;closet;press", + "lamp", + "bathtub;bathing;tub;bath;tub", + "railing;rail", + "cushion", + "base;pedestal;stand", + "box", + "column;pillar", + "signboard;sign", + "chest;of;drawers;chest;bureau;dresser", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace;hearth;open;fireplace", + "refrigerator;icebox", + "grandstand;covered;stand", + "path", + "stairs;steps", + "runway", + "case;display;case;showcase;vitrine", + "pool;table;billiard;table;snooker;table", + "pillow", + "screen;door;screen", + "stairway;staircase", + "river", + "bridge;span", + "bookcase", + "blind;screen", + "coffee;table;cocktail;table", + "toilet;can;commode;crapper;pot;potty;stool;throne", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove;kitchen;stove;range;kitchen;range;cooking;stove", + "palm;palm;tree", + "kitchen;island", + "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system", + "swivel;chair", + "boat", + "bar", + "arcade;machine", + "hovel;hut;hutch;shack;shanty", + "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle", + "towel", + "light;light;source", + "truck;motortruck", + "tower", + "chandelier;pendant;pendent", + "awning;sunshade;sunblind", + "streetlight;street;lamp", + "booth;cubicle;stall;kiosk", + "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box", + "airplane;aeroplane;plane", + "dirt;track", + "apparel;wearing;apparel;dress;clothes", + "pole", + "land;ground;soil", + "bannister;banister;balustrade;balusters;handrail", + "escalator;moving;staircase;moving;stairway", + "ottoman;pouf;pouffe;puff;hassock", + "bottle", + "buffet;counter;sideboard", + "poster;posting;placard;notice;bill;card", + "stage", + "van", + "ship", + "fountain", + "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter", + "canopy", + "washer;automatic;washer;washing;machine", + "plaything;toy", + "swimming;pool;swimming;bath;natatorium", + "stool", + "barrel;cask", + "basket;handbasket", + "waterfall;falls", + "tent;collapsible;shelter", + "bag", + "minibike;motorbike", + "cradle", + "oven", + "ball", + "food;solid;food", + "step;stair", + "tank;storage;tank", + "trade;name;brand;name;brand;marque", + "microwave;microwave;oven", + "pot;flowerpot", + "animal;animate;being;beast;brute;creature;fauna", + "bicycle;bike;wheel;cycle", + "lake", + "dishwasher;dish;washer;dishwashing;machine", + "screen;silver;screen;projection;screen", + "blanket;cover", + "sculpture", + "hood;exhaust;hood", + "sconce", + "vase", + "traffic;light;traffic;signal;stoplight", + "tray", + "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin", + "fan", + "pier;wharf;wharfage;dock", + "crt;screen", + "plate", + "monitor;monitoring;device", + "bulletin;board;notice;board", + "shower", + "radiator", + "glass;drinking;glass", + "clock", + "flag", +] + + +VOC2012_COLORMAP = [ + (0, 0, 0), + (128, 0, 0), + (0, 128, 0), + (128, 128, 0), + (0, 0, 128), + (128, 0, 128), + (0, 128, 128), + (128, 128, 128), + (64, 0, 0), + (192, 0, 0), + (64, 128, 0), + (192, 128, 0), + (64, 0, 128), + (192, 0, 128), + (64, 128, 128), + (192, 128, 128), + (0, 64, 0), + (128, 64, 0), + (0, 192, 0), + (128, 192, 0), + (0, 64, 128), +] + + +VOC2012_CLASS_NAMES = [ + "", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] diff --git a/src/dinov2/eval/segmentation_m2f/__init__.py b/src/dinov2/eval/segmentation_m2f/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c678fdf8f1dee14d7cf9be70af14e6f9a1441c3 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .core import * # noqa: F403 +from .models import * # noqa: F403 +from .ops import * # noqa: F403 diff --git a/src/dinov2/eval/segmentation_m2f/core/__init__.py b/src/dinov2/eval/segmentation_m2f/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92599806fbd221c1418d179892a0f46dc0b7d4db --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmseg.core.evaluation import * # noqa: F403 +from mmseg.core.seg import * # noqa: F403 + +from .anchor import * # noqa: F403 +from .box import * # noqa: F403 +from .utils import * # noqa: F403 diff --git a/src/dinov2/eval/segmentation_m2f/core/anchor/__init__.py b/src/dinov2/eval/segmentation_m2f/core/anchor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e71ac4d6e01462221ae01aa16d0e1231cda7e2e7 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/anchor/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .point_generator import MlvlPointGenerator # noqa: F403 diff --git a/src/dinov2/eval/segmentation_m2f/core/anchor/builder.py b/src/dinov2/eval/segmentation_m2f/core/anchor/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6dba90e22de76d2f23a86d3c057f196d55a99690 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/anchor/builder.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +from mmcv.utils import Registry, build_from_cfg + +PRIOR_GENERATORS = Registry("Generator for anchors and points") + +ANCHOR_GENERATORS = PRIOR_GENERATORS + + +def build_prior_generator(cfg, default_args=None): + return build_from_cfg(cfg, PRIOR_GENERATORS, default_args) + + +def build_anchor_generator(cfg, default_args=None): + warnings.warn("``build_anchor_generator`` would be deprecated soon, please use " "``build_prior_generator`` ") + return build_prior_generator(cfg, default_args=default_args) diff --git a/src/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py b/src/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..574d71939080e22284fe99087fb2e7336657bd97 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn.modules.utils import _pair + +from .builder import PRIOR_GENERATORS + + +@PRIOR_GENERATORS.register_module() +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + + def __init__(self, strides, offset=0.5): + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self): + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x, y, row_major=True): + yy, xx = torch.meshgrid(y, x) + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + + else: + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str): The device where the anchors will be put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[torch.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride + ) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = torch.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype) + stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype) + shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1) + all_points = shifts.to(device) + return all_points + + def valid_flags(self, featmap_sizes, pad_shape, device="cuda"): + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str): The device where the anchors will be put on. + + Return: + list(torch.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"): + """Generate the valid flags of points of a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). + valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str, optional): The device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"): + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`torch.dtype`): Date type of points. Defaults to + ``torch.float32``. + device (obj:`torch.device`): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). + """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1] + prioris = torch.stack([x, y], 1).to(dtype) + prioris = prioris.to(device) + return prioris diff --git a/src/dinov2/eval/segmentation_m2f/core/box/__init__.py b/src/dinov2/eval/segmentation_m2f/core/box/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf35a613f81acd77ecab2dfb75a722fa8e5c0787 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/box/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .builder import * # noqa: F403 +from .samplers import MaskPseudoSampler # noqa: F403 diff --git a/src/dinov2/eval/segmentation_m2f/core/box/builder.py b/src/dinov2/eval/segmentation_m2f/core/box/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9538c0de3db682c2b111b085a8a1ce321c76a9ff --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/box/builder.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.utils import Registry, build_from_cfg + +BBOX_SAMPLERS = Registry("bbox_sampler") +BBOX_CODERS = Registry("bbox_coder") + + +def build_sampler(cfg, **default_args): + """Builder of box sampler.""" + return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) + + +def build_bbox_coder(cfg, **default_args): + """Builder of box coder.""" + return build_from_cfg(cfg, BBOX_CODERS, default_args) diff --git a/src/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py b/src/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19c363e3fabc365d92aeaf1e78189d710db279e9 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403 diff --git a/src/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py b/src/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c45cec3ed7af5b49bb54b92d6e6bcf59b06b4c99 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from abc import ABCMeta, abstractmethod + +import torch + +from .sampling_result import SamplingResult + + +class BaseSampler(metaclass=ABCMeta): + """Base class of samplers.""" + + def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs): + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + self.pos_sampler = self + self.neg_sampler = self + + @abstractmethod + def _sample_pos(self, assign_result, num_expected, **kwargs): + """Sample positive samples.""" + pass + + @abstractmethod + def _sample_neg(self, assign_result, num_expected, **kwargs): + """Sample negative samples.""" + pass + + def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs): + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + bboxes (Tensor): Boxes to be sampled from. + gt_bboxes (Tensor): Ground truth bboxes. + gt_labels (Tensor, optional): Class labels of ground truth bboxes. + + Returns: + :obj:`SamplingResult`: Sampling result. + + Example: + >>> from mmdet.core.bbox import RandomSampler + >>> from mmdet.core.bbox import AssignResult + >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes + >>> rng = ensure_rng(None) + >>> assign_result = AssignResult.random(rng=rng) + >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) + >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) + >>> gt_labels = None + >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, + >>> add_gt_as_proposals=False) + >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) + """ + if len(bboxes.shape) < 2: + bboxes = bboxes[None, :] + + bboxes = bboxes[:, :4] + + gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8) + if self.add_gt_as_proposals and len(gt_bboxes) > 0: + if gt_labels is None: + raise ValueError("gt_labels must be given when add_gt_as_proposals is True") + bboxes = torch.cat([gt_bboxes, bboxes], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs) + neg_inds = neg_inds.unique() + + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags) + return sampling_result diff --git a/src/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py b/src/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3e67ea61ed0fd65cca0addde1893a3c1e176bf15 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py + +import torch + +from ..builder import BBOX_SAMPLERS +from .base_sampler import BaseSampler +from .mask_sampling_result import MaskSamplingResult + + +@BBOX_SAMPLERS.register_module() +class MaskPseudoSampler(BaseSampler): + """A pseudo sampler that does not do sampling actually.""" + + def __init__(self, **kwargs): + pass + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError + + def sample(self, assign_result, masks, gt_masks, **kwargs): + """Directly returns the positive and negative indices of samples. + + Args: + assign_result (:obj:`AssignResult`): Assigned results + masks (torch.Tensor): Bounding boxes + gt_masks (torch.Tensor): Ground truth boxes + Returns: + :obj:`SamplingResult`: sampler results + """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8) + sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags) + return sampling_result diff --git a/src/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py b/src/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py new file mode 100644 index 0000000000000000000000000000000000000000..270ffd35a5f120dd0560a7fea7fe83ef0bab66bb --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py + +import torch + +from .sampling_result import SamplingResult + + +class MaskSamplingResult(SamplingResult): + """Mask sampling result.""" + + def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_masks = masks[pos_inds] + self.neg_masks = masks[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_masks.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_masks.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_masks = torch.empty_like(gt_masks) + else: + self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def masks(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_masks, self.neg_masks]) + + def __nice__(self): + data = self.info.copy() + data["pos_masks"] = data.pop("pos_masks").shape + data["neg_masks"] = data.pop("neg_masks").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_masks": self.pos_masks, + "neg_masks": self.neg_masks, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } diff --git a/src/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py b/src/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py new file mode 100644 index 0000000000000000000000000000000000000000..aaee3fe55aeb8c6da7edefbbd382d94b67b6a6b4 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch + + +class SamplingResult: + """Bbox sampling result. + + Example: + >>> # xdoctest: +IGNORE_WANT + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random(rng=10) + >>> print(f'self = {self}') + self = + """ + + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def bboxes(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_bboxes, self.neg_bboxes]) + + def to(self, device): + """Change the device of the data inplace. + + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + def __nice__(self): + data = self.info.copy() + data["pos_bboxes"] = data.pop("pos_bboxes").shape + data["neg_bboxes"] = data.pop("neg_bboxes").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_bboxes": self.pos_bboxes, + "neg_bboxes": self.neg_bboxes, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } + + @classmethod + def random(cls, rng=None, **kwargs): + """ + Args: + rng (None | int | numpy.random.RandomState): seed or state. + kwargs (keyword arguments): + - num_preds: number of predicted boxes + - num_gts: number of true boxes + - p_ignore (float): probability of a predicted box assigned to \ + an ignored truth. + - p_assigned (float): probability of a predicted box not being \ + assigned. + - p_use_label (float | bool): with labels or not. + + Returns: + :obj:`SamplingResult`: Randomly generated sampling result. + + Example: + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random() + >>> print(self.__dict__) + """ + from mmdet.core.bbox import demodata + from mmdet.core.bbox.assigners.assign_result import AssignResult + from mmdet.core.bbox.samplers.random_sampler import RandomSampler + + rng = demodata.ensure_rng(rng) + + # make probabalistic? + num = 32 + pos_fraction = 0.5 + neg_pos_ub = -1 + + assign_result = AssignResult.random(rng=rng, **kwargs) + + # Note we could just compute an assignment + bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng) + gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng) + + if rng.rand() > 0.2: + # sometimes algorithms squeeze their data, be robust to that + gt_bboxes = gt_bboxes.squeeze() + bboxes = bboxes.squeeze() + + if assign_result.labels is None: + gt_labels = None + else: + gt_labels = None + + if gt_labels is None: + add_gt_as_proposals = False + else: + add_gt_as_proposals = True # make probabalistic? + + sampler = RandomSampler( + num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng + ) + self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + return self diff --git a/src/dinov2/eval/segmentation_m2f/core/utils/__init__.py b/src/dinov2/eval/segmentation_m2f/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6cdc9e19352f50bc2d5433c412ff71186c5df019 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dist_utils import reduce_mean +from .misc import add_prefix, multi_apply diff --git a/src/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py b/src/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7dfed42da821cd94e31b663d86b20b8f09799b30 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch.distributed as dist + + +def reduce_mean(tensor): + """ "Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor diff --git a/src/dinov2/eval/segmentation_m2f/core/utils/misc.py b/src/dinov2/eval/segmentation_m2f/core/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e07579e7b182b62153e81fe637ffd0f3081ef2a3 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/core/utils/misc.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. + + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs diff --git a/src/dinov2/eval/segmentation_m2f/models/__init__.py b/src/dinov2/eval/segmentation_m2f/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed89bb0064d82b4360af020798eab3d2f5a47937 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost +from .decode_heads import * # noqa: F403 +from .losses import * # noqa: F403 +from .plugins import * # noqa: F403 +from .segmentors import * # noqa: F403 diff --git a/src/dinov2/eval/segmentation_m2f/models/backbones/__init__.py b/src/dinov2/eval/segmentation_m2f/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bf73bcbcee710676f81cb6517ae787f4d61cc6 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vit_adapter import ViTAdapter diff --git a/src/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py b/src/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..26bfdf8f6ae6c107d22d61985cce34d4b5ce275f --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp + +from ...ops.modules import MSDeformAttn +from .drop_path import DropPath + + +def get_reference_points(spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + +def deform_inputs(x, patch_size): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor( + [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device + ) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class ConvFFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() + x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() + x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0.0, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + with_cp=False, + ): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.with_cffn = with_cffn + self.with_cp = with_cp + if with_cffn: + self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + query = query + attn + + if self.with_cffn: + query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) + return query + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class Injector(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0.0, + with_cp=False, + ): + super().__init__() + self.with_cp = with_cp + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + return query + self.gamma * attn + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class InteractionBlock(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c + + +class InteractionBlockWithCls(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + x = torch.cat((cls, x), dim=1) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + def __init__(self, inplanes=64, embed_dim=384, with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.stem = nn.Sequential( + *[ + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ] + ) + self.conv2 = nn.Sequential( + *[ + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv3 = nn.Sequential( + *[ + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv4 = nn.Sequential( + *[ + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, x): + def _inner_forward(x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + if self.with_cp and x.requires_grad: + outs = cp.checkpoint(_inner_forward, x) + else: + outs = _inner_forward(x) + return outs diff --git a/src/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py b/src/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..864eb8738c44652d12b979fc811503f21cbb00dd --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/src/dinov2/eval/segmentation_m2f/models/backbones/vit.py b/src/dinov2/eval/segmentation_m2f/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8a147570451bd2fbd016ddfafbbfa33035cbd4f8 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/backbones/vit.py @@ -0,0 +1,552 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +"""Vision Transformer (ViT) in PyTorch. + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +from functools import partial +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.runner import BaseModule, load_checkpoint +from mmseg.ops import resize +from mmseg.utils import get_root_logger +from torch import Tensor + +from .drop_path import DropPath + + +def to_2tuple(x): + return tuple(repeat(x, 2)) + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + swiglu_hidden_features = int(2 * hidden_features / 3) + align_as = 8 + swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as + self.w1 = nn.Linear(in_features, swiglu_hidden_features) + self.w2 = nn.Linear(in_features, swiglu_hidden_features) + self.w3 = nn.Linear(swiglu_hidden_features, out_features) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.w1(x) + x2 = self.w2(x) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding.""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, H, W) -> Tensor: + from xformers.ops import memory_efficient_attention, unbind + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowedAttention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant" + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.window_size = window_size + self.pad_mode = pad_mode + + def forward(self, x, H, W): + B, N, C = x.shape + N_ = self.window_size * self.window_size + H_ = math.ceil(H / self.window_size) * self.window_size + W_ = math.ceil(W / self.window_size) * self.window_size + + qkv = self.qkv(x) # [B, N, C] + qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W] + qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode) + + qkv = F.unfold( + qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size) + ) + B, C_kw_kw, L = qkv.shape # L - the num of windows + qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C] + qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # q,k,v [B, L, num_head, N_, C/num_head] + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] + # if self.mask: + # attn = attn * mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] + # attn @ v = [B, L, num_head, N_, C/num_head] + x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L) + + x = F.fold( + x, + output_size=(H_, W_), + kernel_size=(self.window_size, self.window_size), + stride=(self.window_size, self.window_size), + ) # [B, C, H_, W_] + x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# class WindowedAttention(nn.Module): +# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"): +# super().__init__() +# self.num_heads = num_heads +# head_dim = dim // num_heads +# self.scale = head_dim ** -0.5 +# +# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) +# self.attn_drop = nn.Dropout(attn_drop) +# self.proj = nn.Linear(dim, dim) +# self.proj_drop = nn.Dropout(proj_drop) +# self.window_size = window_size +# self.pad_mode = pad_mode +# +# def forward(self, x, H, W): +# B, N, C = x.shape +# +# N_ = self.window_size * self.window_size +# H_ = math.ceil(H / self.window_size) * self.window_size +# W_ = math.ceil(W / self.window_size) * self.window_size +# x = x.view(B, H, W, C) +# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode) +# +# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C +# x = x.view(-1, N_, C) +# +# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) +# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) +# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] +# attn = attn.softmax(dim=-1) +# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] +# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C) +# +# x = window_reverse(x, self.window_size, H_, W_) +# x = x[:, :H, :W, :].reshape(B, N, C).contiguous() +# x = self.proj(x) +# x = self.proj_drop(x) +# return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + windowed=False, + window_size=14, + pad_mode="constant", + layer_scale=False, + with_cp=False, + ffn_layer=Mlp, + memeff=False, + ): + super().__init__() + self.with_cp = with_cp + self.norm1 = norm_layer(dim) + if windowed: + self.attn = WindowedAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + pad_mode=pad_mode, + ) + elif memeff: + self.attn = MemEffAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + else: + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.layer_scale = layer_scale + if layer_scale: + self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) + self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) + + def forward(self, x, H, W): + def _inner_forward(x): + if self.layer_scale: + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class TIMMVisionTransformer(BaseModule): + """Vision Transformer. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + layer_scale=True, + embed_layer=PatchEmbed, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + window_attn=False, + window_size=14, + pretrained=None, + with_cp=False, + pre_norm=False, + ffn_type="mlp", + memeff=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + pretrained: (str): pretrained path + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.norm_layer = norm_layer + self.act_layer = act_layer + self.pretrain_size = img_size + self.drop_path_rate = drop_path_rate + self.drop_rate = drop_rate + self.patch_size = patch_size + + window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn + window_size = [window_size] * depth if not isinstance(window_size, list) else window_size + logging.info("window attention:", window_attn) + logging.info("window size:", window_size) + logging.info("layer scale:", layer_scale) + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm + ) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN} + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + windowed=window_attn[i], + window_size=window_size[i], + layer_scale=layer_scale, + with_cp=with_cp, + ffn_layer=ffn_types[ffn_type], + memeff=memeff, + ) + for i in range(depth) + ] + ) + + # self.norm = norm_layer(embed_dim) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # For CLIP + if pre_norm: + norm_pre = norm_layer(embed_dim) + self.norm_pre = norm_pre + else: + self.norm_pre = nn.Identity() + self.init_weights(pretrained) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger) + + def forward_features(self, x): + x, H, W = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + + # For CLIP + x = self.norm_pre(x) + + for blk in self.blocks: + x = blk(x, H, W) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + return x + + @staticmethod + def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode): + """Resize pos_embed weights. + + Resize pos_embed using bicubic interpolate method. + Args: + pos_embed (torch.Tensor): Position embedding weights. + input_shpae (tuple): Tuple for (downsampled input image height, + downsampled input image width). + pos_shape (tuple): The resolution of downsampled origin training + image. + mode (str): Algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + Return: + torch.Tensor: The resized pos_embed of shape [B, L_new, C] + """ + assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]" + pos_h, pos_w = pos_shape + # keep dim for easy deployment + cls_token_weight = pos_embed[:, 0:1] + pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :] + pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2) + pos_embed_weight = resize(pos_embed_weight, size=input_shpae, align_corners=False, mode=mode) + pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2) + pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1) + return pos_embed diff --git a/src/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py b/src/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc4f0f65e04ed764464d141607b3b2073220f6b --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py @@ -0,0 +1,217 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import BACKBONES +from torch.nn.init import normal_ + +from ...ops.modules import MSDeformAttn +from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs +from .vit import TIMMVisionTransformer + + +@BACKBONES.register_module() +class ViTAdapter(TIMMVisionTransformer): + def __init__( + self, + pretrain_size=224, + num_heads=12, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0.0, + interaction_indexes=None, + with_cffn=True, + cffn_ratio=0.25, + deform_ratio=1.0, + add_vit_feature=True, + pretrained=None, + use_extra_extractor=True, + freeze_vit=False, + use_cls=True, + with_cp=False, + *args, + **kwargs + ): + + super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs) + if freeze_vit: + for param in self.parameters(): + param.requires_grad = False + + # self.num_classes = 80 + self.use_cls = use_cls + if not self.use_cls: + self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + block_fn = InteractionBlockWithCls if use_cls else InteractionBlock + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) + self.interactions = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor), + with_cp=with_cp, + ) + for i in range(len(interaction_indexes)) + ] + ) + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape( + 1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1 + ).permute(0, 3, 1, 2) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + ) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + # Patch Embedding forward + H_c, W_c = x.shape[2] // 16, x.shape[3] // 16 + x, H_toks, W_toks = self.patch_embed(x) + # print("H_toks, W_toks =", H_toks, W_toks) + bs, n, dim = x.shape + pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks) + if self.use_cls: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1) + x = self.pos_drop(x + pos_embed) + # For CLIP + x = self.norm_pre(x) + + # Interaction + if self.use_cls: + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + outs = list() + for i, layer in enumerate(self.interactions): + indexes = self.interaction_indexes[i] + if self.use_cls: + x, c, cls = layer( + x, + c, + cls, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + else: + x, c = layer( + x, + c, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous()) + + # Split & Reshape + c2 = c[:, 0 : c2.size(1), :] + c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1) :, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + + x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False) + x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False) + x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False) + x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False) + # print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/src/dinov2/eval/segmentation_m2f/models/builder.py b/src/dinov2/eval/segmentation_m2f/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cf7b919f6b0e8e00bde45bc244d9c29a36fed6 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/builder.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.utils import Registry + +TRANSFORMER = Registry("Transformer") +MASK_ASSIGNERS = Registry("mask_assigner") +MATCH_COST = Registry("match_cost") + + +def build_match_cost(cfg): + """Build Match Cost.""" + return MATCH_COST.build(cfg) + + +def build_assigner(cfg): + """Build Assigner.""" + return MASK_ASSIGNERS.build(cfg) + + +def build_transformer(cfg): + """Build Transformer.""" + return TRANSFORMER.build(cfg) diff --git a/src/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py b/src/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01f08b88950750337781fc671adfea2a935ea8fe --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask2former_head import Mask2FormerHead diff --git a/src/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py b/src/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d1705fc444fa8d1583d88fca36d7fe1e060db9e7 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py @@ -0,0 +1,544 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.ops import point_sample +from mmcv.runner import ModuleList, force_fp32 +from mmseg.models.builder import HEADS, build_loss +from mmseg.models.decode_heads.decode_head import BaseDecodeHead + +from ...core import build_sampler, multi_apply, reduce_mean +from ..builder import build_assigner +from ..utils import get_uncertain_point_coords_with_randomness + + +@HEADS.register_module() +class Mask2FormerHead(BaseDecodeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. + loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs, + ): + super(Mask2FormerHead, self).__init__( + in_channels=in_channels, + channels=feat_channels, + num_classes=(num_things_classes + num_stuff_classes), + init_cfg=init_cfg, + input_transform="multiple_select", + **kwargs, + ) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project: + self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding(positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels), + ) + self.conv_seg = None # fix a bug here (conv_seg is not used) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get("num_points", 12544) + self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0) + self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + def init_weights(self): + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas): + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape [num_queries, + cls_out_channels]. + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape [num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for all + images. Each with shape (n, ), n is the sum of number of stuff + type and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[list[Tensor]]: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images. + Each with shape [num_queries, ]. + - label_weights_list (list[Tensor]): Label weights of all + images.Each with shape [num_queries, ]. + - mask_targets_list (list[Tensor]): Mask targets of all images. + Each with shape [num_queries, h, w]. + - mask_weights_list (list[Tensor]): Mask weights of all images. + Each with shape [num_queries, ]. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas + ) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (num_gts, ). + gt_masks (Tensor): Ground truth mask for each image, each with + shape (num_gts, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries,)) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries,)) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + num_total_pos, + num_total_neg, + ) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio + ) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1, 1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + @force_fp32(apply_to=("all_cls_scores", "all_mask_preds")) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list + ) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict["loss_cls"] = losses_cls[-1] + loss_dict["loss_mask"] = losses_mask[-1] + loss_dict["loss_dice"] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i + loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature) + attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). + """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding(mask) + decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + ) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:] + ) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. + + Args: + inputs (list[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + test_cfg (dict): Testing config. + + Returns: + seg_mask (Tensor): Predicted semantic segmentation logits. + """ + all_cls_scores, all_mask_preds = self(inputs, img_metas) + cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] + ori_h, ori_w, _ = img_metas[0]["ori_shape"] + + # semantic inference + cls_score = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred) + return seg_mask diff --git a/src/dinov2/eval/segmentation_m2f/models/losses/__init__.py b/src/dinov2/eval/segmentation_m2f/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..229a887817372f4991b32354180592cfb236d728 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/losses/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy +from .dice_loss import DiceLoss +from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost diff --git a/src/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py b/src/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1f9dd4aa52ebe94cc527db36b1c7fa2f53813e --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss + + +def cross_entropy( + pred, + label, + weight=None, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False, +): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. + # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = label.numel() - (label == ignore_index).sum().item() + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy( + pred, + label, + weight=None, + reduction="mean", + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs, +): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. + assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes" + pred = pred.squeeze() + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), ( + "Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported" + ) + # `weight` returned from `_expand_onehot_labels` + # has been treated for valid (non-ignore) pixels + label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight = weight * valid_mask + else: + weight = valid_mask + # average loss over non-ignored and valid elements + if reduction == "mean" and avg_factor is None and avg_non_ignore: + avg_factor = valid_mask.sum().item() + + loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none") + # do the reduction for the weighted loss + loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy( + pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs +): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, "BCE loss does not support ignore_index" + assert reduction == "mean" and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None] + + +@LOSSES.register_module(force=True) +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + def __init__( + self, + use_sigmoid=False, + use_mask=False, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_ce", + avg_non_ignore=False, + ): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == "mean": + warnings.warn( + "Default ``avg_non_ignore`` is False, if you would like to " + "ignore the certain label and average loss over non-ignore " + "labels, which is the same with PyTorch official " + "cross_entropy, set ``avg_non_ignore=True``." + ) + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f"avg_non_ignore={self.avg_non_ignore}" + return s + + def forward( + self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs + ): + """Forward function.""" + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs, + ) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/src/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py b/src/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc5ba893c502861032ed531283f225e183eb693 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import weight_reduce_loss + + +def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate dice loss, which is proposed in + `V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation `_. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input * input, 1) + eps + c = torch.sum(target * target, 1) + eps + d = (2 * a) / (b + c) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate naive dice loss, the coefficient in the denominator is the + first power instead of the second power. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input, 1) + c = torch.sum(target, 1) + d = (2 * a + eps) / (b + c + eps) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@LOSSES.register_module(force=True) +class DiceLoss(nn.Module): + def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3): + """Dice Loss, there are two forms of dice loss is supported: + + - the one proposed in `V-Net: Fully Convolutional Neural + Networks for Volumetric Medical Image Segmentation + `_. + - the dice loss in which the power of the number in the + denominator is the first power instead of the second + power. + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + activate (bool): Whether to activate the predictions inside, + this will disable the inside sigmoid operation. + Defaults to True. + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + naive_dice (bool, optional): If false, use the dice + loss defined in the V-Net paper, otherwise, use the + naive dice loss in which the power of the number in the + denominator is the first power instead of the second + power.Defaults to False. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + """ + + super(DiceLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.eps = eps + self.activate = activate + + def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None): + """Forward function. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *). + target (torch.Tensor): The label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + + if self.activate: + if self.use_sigmoid: + pred = pred.sigmoid() + else: + raise NotImplementedError + + if self.naive_dice: + loss = self.loss_weight * naive_dice_loss( + pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor + ) + else: + loss = self.loss_weight * dice_loss( + pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor + ) + + return loss diff --git a/src/dinov2/eval/segmentation_m2f/models/losses/match_costs.py b/src/dinov2/eval/segmentation_m2f/models/losses/match_costs.py new file mode 100644 index 0000000000000000000000000000000000000000..4917d2a939c01398dd49c0d90b06f4c37d283ce0 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/losses/match_costs.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +from ..builder import MATCH_COST + + +@MATCH_COST.register_module() +class ClassificationCost: + """ClsSoftmaxCost.Borrow from + mmdet.core.bbox.match_costs.match_cost.ClassificationCost. + + Args: + weight (int | float, optional): loss_weight + + Examples: + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight=1.0): + self.weight = weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + # Following the official DETR repo, contrary to the loss that + # NLL is used, we approximate it in 1 - cls_score[gt_label]. + # The 1 is a constant that doesn't change the matching, + # so it can be omitted. + cls_score = cls_pred.softmax(-1) + cls_cost = -cls_score[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class DiceCost: + """Cost of mask assignments based on dice losses. + + Args: + weight (int | float, optional): loss_weight. Defaults to 1. + pred_act (bool, optional): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float, optional): default 1e-12. + """ + + def __init__(self, weight=1.0, pred_act=False, eps=1e-3): + self.weight = weight + self.pred_act = pred_act + self.eps = eps + + def binary_mask_dice_loss(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + mask_preds = mask_preds.reshape((mask_preds.shape[0], -1)) + gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float() + numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks) + denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction logits in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W). + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + if self.pred_act: + mask_preds = mask_preds.sigmoid() + dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) + return dice_cost * self.weight + + +@MATCH_COST.register_module() +class CrossEntropyLossCost: + """CrossEntropyLossCost. + + Args: + weight (int | float, optional): loss weight. Defaults to 1. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to True. + """ + + def __init__(self, weight=1.0, use_sigmoid=True): + assert use_sigmoid, "use_sigmoid = False is not supported yet." + self.weight = weight + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): The prediction with shape (num_query, 1, *) or + (num_query, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + Returns: + Tensor: Cross entropy cost matrix in shape (num_query, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none") + neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none") + cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits. + gt_labels (Tensor): Labels. + Returns: + Tensor: Cross entropy cost matrix with weight in + shape (num_query, num_gt). + """ + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(cls_pred, gt_labels) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/src/dinov2/eval/segmentation_m2f/models/plugins/__init__.py b/src/dinov2/eval/segmentation_m2f/models/plugins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a60db4de31238cb38e078683e5ca265839fe60 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/plugins/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder diff --git a/src/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py b/src/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..db1947175917f73f3f24184cb09c78e092d46ef8 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.runner import BaseModule, ModuleList + +from ...core.anchor import MlvlPointGenerator +from ..utils.transformer import MultiScaleDeformableAttention + + +@PLUGIN_LAYERS.register_module() +class MSDeformAttnPixelDecoder(BaseModule): + """Pixel decoder with multi-scale deformable attention. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + strides (list[int] | tuple[int]): Output strides of feature from + backbone. + feat_channels (int): Number of channels for feature. + out_channels (int): Number of channels for output. + num_outs (int): Number of output scales. + norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer + encoder. Defaults to `DetrTransformerEncoder`. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict. + """ + + def __init__( + self, + in_channels=[256, 512, 1024, 2048], + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_outs=3, + norm_cfg=dict(type="GN", num_groups=32), + act_cfg=dict(type="ReLU"), + encoder=dict( + type="DetrTransformerEncoder", + num_layers=6, + transformerlayers=dict( + type="BaseTransformerLayer", + attn_cfgs=dict( + type="MultiScaleDeformableAttention", + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None, + ), + feedforward_channels=1024, + ffn_dropout=0.0, + operation_order=("self_attn", "norm", "ffn", "norm"), + ), + init_cfg=None, + ), + positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.strides = strides + self.num_input_levels = len(in_channels) + self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels + assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one" + input_conv_list = [] + # from top to down (low to high resolution) + for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1): + input_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True + ) + input_conv_list.append(input_conv) + self.input_convs = ModuleList(input_conv_list) + + self.encoder = build_transformer_layer_sequence(encoder) + self.postional_encoding = build_positional_encoding(positional_encoding) + # high resolution to low resolution + self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels) + + # fpn-like structure + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + # from top to down (low to high resolution) + # fpn for the rest features that didn't pass in encoder + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + lateral_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None + ) + output_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0) + + self.num_outs = num_outs + self.point_generator = MlvlPointGenerator(strides) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_encoder_levels): + xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform") + + for i in range(0, self.num_input_levels - self.num_encoder_levels): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + + normal_init(self.level_encoding, mean=0, std=1) + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # init_weights defined in MultiScaleDeformableAttention + for layer in self.encoder.layers: + for attn in layer.attentions: + if isinstance(attn, MultiScaleDeformableAttention): + attn.init_weights() + + def forward(self, feats): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). + + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + h, w = feat.shape[-2:] + + # no padding + padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device + ) + # normalize + factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(2, 0, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat.shape[-2:]) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_query), + # total_num_query=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_query, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=0) + level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0) + device = encoder_inputs.device + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2)) + # shape (num_total_query, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + key=None, + value=None, + query_pos=level_positional_encodings, + key_pos=None, + attn_masks=None, + key_padding_mask=None, + query_key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_radios=valid_radios, + ) + # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) + memory = memory.permute(1, 2, 0) + + # from low resolution to high resolution + num_query_per_level = [e[0] * e[1] for e in spatial_shapes] + outs = torch.split(memory, num_query_per_level, dim=-1) + outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[: self.num_outs] + + mask_feature = self.mask_feature(outs[-1]) + return mask_feature, multi_scale_features diff --git a/src/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py b/src/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adf0062691e4889612e118f28ced853cd0bc33db --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .encoder_decoder_mask2former import EncoderDecoderMask2Former diff --git a/src/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py b/src/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe572c9d317303bff8d51b85217d144906ebfe7 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.core import add_prefix +from mmseg.models import builder +from mmseg.models.builder import SEGMENTORS +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize + + +@SEGMENTORS.register_module() +class EncoderDecoderMask2Former(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + + def __init__( + self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + ): + super(EncoderDecoderMask2Former, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight" + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, f"aux_{idx}")) + else: + loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, "aux")) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + if rescale: + preds = resize( + preds, + size=img_meta[0]["ori_shape"][:2], + mode="bilinear", + align_corners=self.align_corners, + warning=False, + ) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]["ori_shape"][:2] + seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. + """ + + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/src/dinov2/eval/segmentation_m2f/models/utils/__init__.py b/src/dinov2/eval/segmentation_m2f/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fdc1668b1015c8feea8fa1a4691bc0ebdbd936 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .assigner import MaskHungarianAssigner +from .point_sample import get_uncertain_point_coords_with_randomness +from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding +from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer diff --git a/src/dinov2/eval/segmentation_m2f/models/utils/assigner.py b/src/dinov2/eval/segmentation_m2f/models/utils/assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb08fc1bb2e36336989b45a1d3850f260c05963 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/utils/assigner.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from abc import ABCMeta, abstractmethod + +import torch + +from ..builder import MASK_ASSIGNERS, build_match_cost + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + + +class AssignResult(metaclass=ABCMeta): + """Collection of assign results.""" + + def __init__(self, num_gts, gt_inds, labels): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.labels = labels + + @property + def info(self): + info = { + "num_gts": self.num_gts, + "gt_inds": self.gt_inds, + "labels": self.labels, + } + return info + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns boxes to ground truth boxes.""" + + @abstractmethod + def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None): + """Assign boxes to either a ground truth boxes or a negative boxes.""" + pass + + +@MASK_ASSIGNERS.register_module() +class MaskHungarianAssigner(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth for + mask. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost, regression L1 cost and regression iou cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config. + mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config. + dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config. + """ + + def __init__( + self, + cls_cost=dict(type="ClassificationCost", weight=1.0), + dice_cost=dict(type="DiceCost", weight=1.0), + mask_cost=dict(type="MaskFocalCost", weight=1.0), + ): + self.cls_cost = build_match_cost(cls_cost) + self.dice_cost = build_match_cost(dice_cost) + self.mask_cost = build_match_cost(mask_cost) + + def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + mask_pred (Tensor): Predicted mask, shape [num_query, h, w] + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w]. + gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_masks_ignore (Tensor, optional): Ground truth masks that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + Returns: + :obj:`AssignResult`: The assigned result. + """ + assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported." + num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0] + + # 1. assign -1 by default + assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + if num_gts == 0 or num_queries == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and maskcost. + if self.cls_cost.weight != 0 and cls_pred is not None: + cls_cost = self.cls_cost(cls_pred, gt_labels) + else: + cls_cost = 0 + + if self.mask_cost.weight != 0: + # mask_pred shape = [nq, h, w] + # gt_mask shape = [ng, h, w] + # mask_cost shape = [nq, ng] + mask_cost = self.mask_cost(mask_pred, gt_masks) + else: + mask_cost = 0 + + if self.dice_cost.weight != 0: + dice_cost = self.dice_cost(mask_pred, gt_masks) + else: + dice_cost = 0 + cost = cls_cost + mask_cost + dice_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' "to install scipy first.") + + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) diff --git a/src/dinov2/eval/segmentation_m2f/models/utils/point_sample.py b/src/dinov2/eval/segmentation_m2f/models/utils/point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1134082bafb51432618a9632592db070f87284 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/utils/point_sample.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +from mmcv.ops import point_sample + + +def get_uncertainty(mask_pred, labels): + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_pred' for the foreground class in `classes`. + + Args: + mask_pred (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (list[Tensor]): Either predicted or ground truth label for + each predicted mask, of length num_rois. + + Returns: + scores (Tensor): Uncertainty scores with the most uncertain + locations having the highest uncertainty score, + shape (num_rois, 1, mask_height, mask_width) + """ + if mask_pred.shape[1] == 1: + gt_class_logits = mask_pred.clone() + else: + inds = torch.arange(mask_pred.shape[0], device=mask_pred.device) + gt_class_logits = mask_pred[inds, labels].unsqueeze(1) + return -torch.abs(gt_class_logits) + + +def get_uncertain_point_coords_with_randomness( + mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio +): + """Get ``num_points`` most uncertain points with random points during + train. + + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_pred (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (list): The ground truth class for each instance. + num_points (int): The number of points to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled + via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates sampled points. + """ + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = mask_pred.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device) + point_logits = point_sample(mask_pred, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = get_uncertainty(point_logits, labels) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device) + point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) + return point_coords diff --git a/src/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py b/src/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5d6fabe946d06fe97cc799da47bae93758b34e --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING +from mmcv.runner import BaseModule + + +@POSITIONAL_ENCODING.register_module() +class SinePositionalEncoding(BaseModule): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__( + self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None + ): + super(SinePositionalEncoding, self).__init__(init_cfg) + if normalize: + assert isinstance(scale, (float, int)), ( + "when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}" + ) + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward(self, mask): + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. + mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + B, H, W = mask.size() + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"temperature={self.temperature}, " + repr_str += f"normalize={self.normalize}, " + repr_str += f"scale={self.scale}, " + repr_str += f"eps={self.eps})" + return repr_str + + +@POSITIONAL_ENCODING.register_module() +class LearnedPositionalEncoding(BaseModule): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Default 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Default 50. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")): + super(LearnedPositionalEncoding, self).__init__(init_cfg) + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + def forward(self, mask): + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = ( + torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(mask.shape[0], 1, 1, 1) + ) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"row_num_embed={self.row_num_embed}, " + repr_str += f"col_num_embed={self.col_num_embed})" + return repr_str diff --git a/src/dinov2/eval/segmentation_m2f/models/utils/transformer.py b/src/dinov2/eval/segmentation_m2f/models/utils/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8befe6011a34d5ccecb82c8b17b61e19f732f96b --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/models/utils/transformer.py @@ -0,0 +1,989 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import warnings +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE +from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence +from mmcv.runner.base_module import BaseModule, Sequential +from mmcv.utils import deprecated_api_warning, to_2tuple +from torch.nn.init import normal_ + +from ..builder import TRANSFORMER + +try: + from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention + +except ImportError: + warnings.warn( + "`MultiScaleDeformableAttention` in MMCV has been moved to " + "`mmcv.ops.multi_scale_deform_attn`, please update your MMCV" + ) + from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". + Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + + def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"): + + super(AdaptivePadding, self).__init__() + + assert padding in ("same", "corner") + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == "corner": + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == "same": + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return x + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding="corner", + dilation=1, + bias=False, + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}" + + H, W = input_size + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = ( + H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1 + ) // self.sampler.stride[0] + 1 + out_w = ( + W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1 + ) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +@FEEDFORWARD_NETWORK.register_module(force=True) +class FFN(BaseModule): + """Implements feed-forward networks (FFNs) with identity connection. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN") + def __init__( + self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type="ReLU", inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True, + init_cfg=None, + with_cp=False, + **kwargs, + ): + super().__init__(init_cfg) + assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}." + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + self.with_cp = with_cp + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + @deprecated_api_warning({"residual": "identity"}, cls_name="FFN") + def forward(self, x, identity=None): + """Forward function for `FFN`. + The function would add x to the output tensor if residue is None. + """ + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.layers, x) + else: + out = self.layers(x) + + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +@TRANSFORMER_LAYER.register_module() +class DetrTransformerDecoderLayer(BaseTransformerLayer): + """Implements decoder layer in DETR transformer. + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): + Configs for self_attention or cross_attention, the order + should be consistent with it in `operation_order`. If it is + a dict, it would be expand to the number of attention in + `operation_order`. + feedforward_channels (int): The hidden dimension for FFNs. + ffn_dropout (float): Probability of an element to be zeroed + in ffn. Default 0.0. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Default:None + act_cfg (dict): The activation config for FFNs. Default: `LN` + norm_cfg (dict): Config dict for normalization layer. + Default: `LN`. + ffn_num_fcs (int): The number of fully-connected layers in FFNs. + Default:2. + """ + + def __init__( + self, + attn_cfgs, + feedforward_channels, + ffn_dropout=0.0, + operation_order=None, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + ffn_num_fcs=2, + **kwargs, + ): + super(DetrTransformerDecoderLayer, self).__init__( + attn_cfgs=attn_cfgs, + feedforward_channels=feedforward_channels, + ffn_dropout=ffn_dropout, + operation_order=operation_order, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ffn_num_fcs=ffn_num_fcs, + **kwargs, + ) + assert len(operation_order) == 6 + assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"]) + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerEncoder(TransformerLayerSequence): + """TransformerEncoder of DETR. + + Args: + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. Only used when `self.pre_norm` is `True` + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs): + super(DetrTransformerEncoder, self).__init__(*args, **kwargs) + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None + else: + assert not self.pre_norm, f"Use prenorm in " f"{self.__class__.__name__}," f"Please specify post_norm_cfg" + self.post_norm = None + + def forward(self, *args, **kwargs): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) + if self.post_norm is not None: + x = self.post_norm(x) + return x + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs): + + super(DetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] + else: + self.post_norm = None + + def forward(self, query, *args, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. + """ + if not self.return_intermediate: + x = super().forward(query, *args, **kwargs) + if self.post_norm: + x = self.post_norm(x)[None] + return x + + intermediate = [] + for layer in self.layers: + query = layer(query, *args, **kwargs) + if self.return_intermediate: + if self.post_norm is not None: + intermediate.append(self.post_norm(query)) + else: + intermediate.append(query) + return torch.stack(intermediate) + + +@TRANSFORMER.register_module() +class Transformer(BaseModule): + """Implements the DETR transformer. + + Following the official DETR implementation, this module copy-paste + from torch.nn.Transformer with modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + + See `paper: End-to-End Object Detection with Transformers + `_ for details. + + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. + """ + + def __init__(self, encoder=None, decoder=None, init_cfg=None): + super(Transformer, self).__init__(init_cfg=init_cfg) + self.encoder = build_transformer_layer_sequence(encoder) + self.decoder = build_transformer_layer_sequence(decoder) + self.embed_dims = self.encoder.embed_dims + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, "weight") and m.weight.dim() > 1: + xavier_init(m, distribution="uniform") + self._is_init = True + + def forward(self, x, mask, query_embed, pos_embed): + """Forward function for `Transformer`. + + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. + """ + bs, c, h, w = x.shape + # use `view` instead of `flatten` for dynamically exporting to ONNX + x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] + pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask + ) + out_dec = out_dec.transpose(1, 2) + memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) + return out_dec, memory + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DeformableDetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, return_intermediate=False, **kwargs): + + super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + + def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + reference_points (Tensor): The reference + points of offset. has shape + (bs, num_query, 4) when as_two_stage, + otherwise has shape ((bs, num_query, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + reg_branch: (obj:`nn.ModuleList`): Used for + refining the regression results. Only would + be passed when with_box_refine is True, + otherwise would be passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] + output = layer(output, *args, reference_points=reference_points_input, **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +@TRANSFORMER.register_module() +class DeformableDetrTransformer(Transformer): + """Implements the DeformableDETR transformer. + + Args: + as_two_stage (bool): Generate query from encoder features. + Default: False. + num_feature_levels (int): Number of feature maps from FPN: + Default: 4. + two_stage_num_proposals (int): Number of proposals when set + `as_two_stage` as True. Default: 300. + """ + + def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs): + super(DeformableDetrTransformer, self).__init__(**kwargs) + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.two_stage_num_proposals = two_stage_num_proposals + self.embed_dims = self.encoder.embed_dims + self.init_layers() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points = nn.Linear(self.embed_dims, 2) + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if not self.as_two_stage: + xavier_init(self.reference_points, distribution="uniform", bias=0.0) + normal_(self.level_embeds) + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor) : The output of encoder, + has shape (bs, num_key, embed_dim). num_key is + equal the number of points on feature map from + all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). + + Returns: + tuple: A tuple of feature map and bbox prediction. + + - output_memory (Tensor): The input of decoder, \ + has shape (bs, num_key, embed_dim). num_key is \ + equal the number of points on feature map from \ + all levels. + - output_proposals (Tensor): The normalized proposal \ + after a inverse sigmoid, has shape \ + (bs, num_keys, 4). + """ + + N, S, C = memory.shape + proposals = [] + _cur = 0 + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N, -1, 4) + proposals.append(proposal) + _cur += H * W + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """Get the reference points used in decoder. + + Args: + spatial_shapes (Tensor): The shape of all + feature maps, has shape (num_level, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + device (obj:`device`): The device where + reference_points should be. + + Returns: + Tensor: reference points used in decoder, has \ + shape (bs, num_keys, num_levels, 2). + """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def get_valid_ratio(self, mask): + """Get the valid radios of feature maps of all level.""" + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): + """Get the position embedding of proposal.""" + scale = 2 * math.pi + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def forward( + self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs + ): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. + """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs, + ) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs, + ) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact + return inter_states, init_reference_out, inter_references_out, None, None + + +@TRANSFORMER.register_module() +class DynamicConv(BaseModule): + """Implements Dynamic Convolution. + + This module generate parameters for each sample and + use bmm to implement 1*1 convolution. Code is modified + from the `official github repo `_ . + + Args: + in_channels (int): The input feature channel. + Defaults to 256. + feat_channels (int): The inner feature channel. + Defaults to 64. + out_channels (int, optional): The output feature channel. + When not specified, it will be set to `in_channels` + by default + input_feat_shape (int): The shape of input feature. + Defaults to 7. + with_proj (bool): Project two-dimentional feature to + one-dimentional feature. Default to True. + act_cfg (dict): The activation config for DynamicConv. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + input_feat_shape=7, + with_proj=True, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super(DynamicConv, self).__init__(init_cfg) + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.input_feat_shape = input_feat_shape + self.with_proj = with_proj + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.in_channels * self.feat_channels + self.num_params_out = self.out_channels * self.feat_channels + self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out) + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + num_output = self.out_channels * input_feat_shape**2 + if self.with_proj: + self.fc_layer = nn.Linear(num_output, self.out_channels) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, param_feature, input_feature): + """Forward function for `DynamicConv`. + + Args: + param_feature (Tensor): The feature can be used + to generate the parameter, has shape + (num_all_proposals, in_channels). + input_feature (Tensor): Feature that + interact with parameters, has shape + (num_all_proposals, in_channels, H, W). + + Returns: + Tensor: The output feature has shape + (num_all_proposals, out_channels). + """ + input_feature = input_feature.flatten(2).permute(2, 0, 1) + + input_feature = input_feature.permute(1, 0, 2) + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels) + param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels) + + # input_feature has shape (num_all_proposals, H*W, in_channels) + # param_in has shape (num_all_proposals, in_channels, feat_channels) + # feature has shape (num_all_proposals, H*W, feat_channels) + features = torch.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = torch.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features diff --git a/src/dinov2/eval/segmentation_m2f/ops/modules/__init__.py b/src/dinov2/eval/segmentation_m2f/ops/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49aa8fe612fd4c088e294707c5ee16bd1cb5b5e7 --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/ops/modules/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules +# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 + +from .ms_deform_attn import MSDeformAttn diff --git a/src/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py b/src/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b4fa23712e87d1a2682b57e71ee37fe8524cff --- /dev/null +++ b/src/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import warnings + +import torch +import torch.nn.functional as F +from torch import nn +from torch.autograd import Function +from torch.cuda.amp import custom_fwd +from torch.nn.init import constant_, xavier_uniform_ + + +class MSDeformAttnFunction(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step + ): + output = ms_deform_attn_core_pytorch( + value, + value_spatial_shapes, + # value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + +def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): + # for debug and test only, + # need to use cuda version instead + N_, S_, M_, D_ = value.shape + _, Lq_, M_, L_, P_, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for lid_, (H_, W_) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) + output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_) + return output.transpose(1, 2).contiguous() + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0): + """Multi-Scale Deformable Attention Module. + + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError("d_model must be divisible by n_heads, " "but got {} and {}".format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 + # which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn( + "You'd better set d_model in MSDeformAttn to make " + "the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation." + ) + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + self.ratio = ratio + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, int(d_model * ratio)) + self.output_proj = nn.Linear(int(d_model * ratio), d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward( + self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None, + ): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + # print(query.shape) + # print(reference_points.shape) + # print(input_flatten.shape) + # print(input_spatial_shapes.shape) + # print(input_level_start_index.shape) + # print(input_spatial_shapes) + # print(input_level_start_index) + + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + + value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = ( + reference_points[:, :, None, :, None, :] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/src/dinov2/eval/setup.py b/src/dinov2/eval/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..959128c0673cc51036dbf17dcc4ee68a037988fb --- /dev/null +++ b/src/dinov2/eval/setup.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +from typing import Any, List, Optional, Tuple + +import torch +import torch.backends.cudnn as cudnn + +from dinov2.models import build_model_from_cfg +from dinov2.utils.config import setup +import dinov2.utils.utils as dinov2_utils + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parser = argparse.ArgumentParser( + description=description, + parents=parents or [], + add_help=add_help, + ) + parser.add_argument( + "--config-file", + type=str, + help="Model configuration file", + ) + parser.add_argument( + "--pretrained-weights", + type=str, + help="Pretrained model weights", + ) + parser.add_argument( + "--output-dir", + default="", + type=str, + help="Output directory to write results and logs", + ) + parser.add_argument( + "--opts", + help="Extra configuration options", + default=[], + nargs="+", + ) + return parser + + +def get_autocast_dtype(config): + teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype + if teacher_dtype_str == "fp16": + return torch.half + elif teacher_dtype_str == "bf16": + return torch.bfloat16 + else: + return torch.float + + +def build_model_for_eval(config, pretrained_weights): + model, _ = build_model_from_cfg(config, only_teacher=True) + dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") + model.eval() + model.cuda() + return model + + +def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: + cudnn.benchmark = True + config = setup(args) + model = build_model_for_eval(config, args.pretrained_weights) + autocast_dtype = get_autocast_dtype(config) + return model, autocast_dtype diff --git a/src/dinov2/eval/utils.py b/src/dinov2/eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c50576b1940587ee64b7a422e2e96b475d60fd39 --- /dev/null +++ b/src/dinov2/eval/utils.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, Optional + +import torch +from torch import nn +from torchmetrics import MetricCollection + +from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader +import dinov2.distributed as distributed +from dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +class ModelWithNormalize(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, samples): + return nn.functional.normalize(self.model(samples), dim=1, p=2) + + +class ModelWithIntermediateLayers(nn.Module): + def __init__(self, feature_model, n_last_blocks, autocast_ctx): + super().__init__() + self.feature_model = feature_model + self.feature_model.eval() + self.n_last_blocks = n_last_blocks + self.autocast_ctx = autocast_ctx + + def forward(self, images): + with torch.inference_mode(): + with self.autocast_ctx(): + features = self.feature_model.get_intermediate_layers( + images, self.n_last_blocks, return_class_token=True + ) + return features + + +@torch.inference_mode() +def evaluate( + model: nn.Module, + data_loader, + postprocessors: Dict[str, nn.Module], + metrics: Dict[str, MetricCollection], + device: torch.device, + criterion: Optional[nn.Module] = None, +): + model.eval() + if criterion is not None: + criterion.eval() + + for metric in metrics.values(): + metric = metric.to(device) + + metric_logger = MetricLogger(delimiter=" ") + header = "Test:" + + for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): + outputs = model(samples.to(device)) + targets = targets.to(device) + + if criterion is not None: + loss = criterion(outputs, targets) + metric_logger.update(loss=loss.item()) + + for k, metric in metrics.items(): + metric_inputs = postprocessors[k](outputs, targets) + metric.update(**metric_inputs) + + metric_logger.synchronize_between_processes() + logger.info(f"Averaged stats: {metric_logger}") + + stats = {k: metric.compute() for k, metric in metrics.items()} + metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + return metric_logger_stats, stats + + +def all_gather_and_flatten(tensor_rank): + tensor_all_ranks = torch.empty( + distributed.get_global_size(), + *tensor_rank.shape, + dtype=tensor_rank.dtype, + device=tensor_rank.device, + ) + tensor_list = list(tensor_all_ranks.unbind(0)) + torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) + return tensor_all_ranks.flatten(end_dim=1) + + +def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): + dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) + sample_count = len(dataset_with_enumerated_targets) + data_loader = make_data_loader( + dataset=dataset_with_enumerated_targets, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + ) + return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) + + +@torch.inference_mode() +def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): + gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") + metric_logger = MetricLogger(delimiter=" ") + features, all_labels = None, None + for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): + samples = samples.cuda(non_blocking=True) + labels_rank = labels_rank.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + features_rank = model(samples).float() + + # init storage feature matrix + if features is None: + features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) + labels_shape = list(labels_rank.shape) + labels_shape[0] = sample_count + all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) + logger.info(f"Storing features into tensor of shape {features.shape}") + + # share indexes, features and labels between processes + index_all = all_gather_and_flatten(index).to(gather_device) + features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) + labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) + + # update storage feature matrix + if len(index_all) > 0: + features.index_copy_(0, index_all, features_all_ranks) + all_labels.index_copy_(0, index_all, labels_all_ranks) + + logger.info(f"Features shape: {tuple(features.shape)}") + logger.info(f"Labels shape: {tuple(all_labels.shape)}") + + assert torch.all(all_labels > -1) + + return features, all_labels diff --git a/src/dinov2/fsdp/__init__.py b/src/dinov2/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed454480e0b76e761d657cc40fd097bd339d15a2 --- /dev/null +++ b/src/dinov2/fsdp/__init__.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Any + +import torch +import dinov2.distributed as distributed +from functools import partial +from fvcore.common.checkpoint import Checkpointer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp._runtime_utils import _reshard + + +def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): + sharding_strategy_dict = { + "NO_SHARD": ShardingStrategy.NO_SHARD, + "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, + "FULL_SHARD": ShardingStrategy.FULL_SHARD, + } + + dtype_dict = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + + mixed_precision_config = MixedPrecision( + param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], + reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], + buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], + ) + + sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] + + local_rank = distributed.get_local_rank() + + fsdp_wrapper = partial( + FSDP, + sharding_strategy=sharding_strategy_config, + mixed_precision=mixed_precision_config, + device_id=local_rank, + sync_module_states=True, + use_orig_params=True, + auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), + ) + return fsdp_wrapper + + +def is_fsdp(x): + return isinstance(x, FSDP) + + +def is_sharded_fsdp(x): + return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD + + +def free_if_fsdp(x): + if is_sharded_fsdp(x): + handles = x._handles + true_list = [True for h in handles] + _reshard(x, handles, true_list) + + +def get_fsdp_modules(x): + return FSDP.fsdp_modules(x) + + +def reshard_fsdp_model(x): + for m in get_fsdp_modules(x): + free_if_fsdp(m) + + +def rankstr(): + return f"rank_{distributed.get_global_rank()}" + + +class FSDPCheckpointer(Checkpointer): + def save(self, name: str, **kwargs: Any) -> None: + """ + Dump model and checkpointables to a file. + + Args: + name (str): name of the file. + kwargs (dict): extra arbitrary data to save. + """ + if not self.save_dir or not self.save_to_disk: + return + + data = {} + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + data["model"] = self.model.state_dict() + + # data["model"] = self.model.state_dict() + for key, obj in self.checkpointables.items(): + data[key] = obj.state_dict() + data.update(kwargs) + + basename = f"{name}.{rankstr()}.pth" + save_file = os.path.join(self.save_dir, basename) + assert os.path.basename(save_file) == basename, basename + self.logger.info("Saving checkpoint to {}".format(save_file)) + with self.path_manager.open(save_file, "wb") as f: + torch.save(data, f) + self.tag_last_checkpoint(basename) + + def load(self, *args, **kwargs): + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + return super().load(*args, **kwargs) + + def has_checkpoint(self) -> bool: + """ + Returns: + bool: whether a checkpoint exists in the target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + return self.path_manager.exists(save_file) + + def get_checkpoint_file(self) -> str: + """ + Returns: + str: The latest checkpoint file in target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + try: + with self.path_manager.open(save_file, "r") as f: + last_saved = f.read().strip() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + return "" + # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got + # `Union[bytes, str]`. + return os.path.join(self.save_dir, last_saved) + + def tag_last_checkpoint(self, last_filename_basename: str) -> None: + """ + Tag the last checkpoint. + + Args: + last_filename_basename (str): the basename of the last filename. + """ + if distributed.is_enabled(): + torch.distributed.barrier() + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + with self.path_manager.open(save_file, "w") as f: + f.write(last_filename_basename) # pyre-ignore + + +ShardedGradScaler = ShardedGradScaler diff --git a/src/dinov2/hub/__init__.py b/src/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/dinov2/hub/backbones.py b/src/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/src/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/src/dinov2/hub/classifiers.py b/src/dinov2/hub/classifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0841efa80ab3d564cd320d61da254af182606b --- /dev/null +++ b/src/dinov2/hub/classifiers.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch +import torch.nn as nn + +from .backbones import _make_dinov2_model +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + IMAGENET1K = "IMAGENET1K" + + +def _make_dinov2_linear_classification_head( + *, + arch_name: str = "vit_large", + patch_size: int = 14, + embed_dim: int = 1024, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) + + if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + layers_str = str(layers) if layers == 4 else "" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + linear_head.load_state_dict(state_dict, strict=True) + + return linear_head + + +class _LinearClassifierWrapper(nn.Module): + def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4): + super().__init__() + self.backbone = backbone + self.linear_head = linear_head + self.layers = layers + + def forward(self, x): + if self.layers == 1: + x = self.backbone.forward_features(x) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + # fmt: off + linear_input = torch.cat([ + cls_token, + patch_tokens.mean(dim=1), + ], dim=1) + # fmt: on + elif self.layers == 4: + x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) + # fmt: off + linear_input = torch.cat([ + x[0][1], + x[1][1], + x[2][1], + x[3][1], + x[3][0].mean(dim=1), + ], dim=1) + # fmt: on + else: + assert False, f"Unsupported number of layers: {self.layers}" + return self.linear_head(linear_input) + + +def _make_dinov2_linear_classifier( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + **kwargs, +): + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + linear_head = _make_dinov2_linear_classification_head( + arch_name=arch_name, + patch_size=patch_size, + embed_dim=embed_dim, + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=num_register_tokens, + ) + + return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) + + +def dinov2_vits14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitb14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitl14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vitg14_lc( + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/src/dinov2/hub/depth/__init__.py b/src/dinov2/hub/depth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91716e58ab6158d814df8c653644d9af4c7be65c --- /dev/null +++ b/src/dinov2/hub/depth/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .decode_heads import BNHead, DPTHead +from .encoder_decoder import DepthEncoderDecoder diff --git a/src/dinov2/hub/depth/decode_heads.py b/src/dinov2/hub/depth/decode_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..f455accad38fec6ecdd53460233a564c34f434da --- /dev/null +++ b/src/dinov2/hub/depth/decode_heads.py @@ -0,0 +1,747 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy +from functools import partial +import math +import warnings + +import torch +import torch.nn as nn + +from .ops import resize + + +# XXX: (Untested) replacement for mmcv.imdenormalize() +def _imdenormalize(img, mean, std, to_bgr=True): + import numpy as np + + mean = mean.reshape(1, -1).astype(np.float64) + std = std.reshape(1, -1).astype(np.float64) + img = (img * std) + mean + if to_bgr: + img = img[::-1] + return img + + +class DepthBaseDecodeHead(nn.Module): + """Base class for BaseDecodeHead. + + Args: + in_channels (List): Input channels. + channels (int): Channels after modules, before conv_depth. + conv_layer (nn.Module): Conv layers. Default: None. + act_layer (nn.Module): Activation layers. Default: nn.ReLU. + loss_decode (dict): Config of decode loss. + Default: (). + sampler (dict|None): The config of depth map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + min_depth (int): Min depth in dataset setting. + Default: 1e-3. + max_depth (int): Max depth in dataset setting. + Default: None. + norm_layer (dict|None): Norm layers. + Default: None. + classify (bool): Whether predict depth in a cls.-reg. manner. + Default: False. + n_bins (int): The number of bins used in cls. step. + Default: 256. + bins_strategy (str): The discrete strategy used in cls. step. + Default: 'UD'. + norm_strategy (str): The norm strategy on cls. probability + distribution. Default: 'linear' + scale_up (str): Whether predict depth in a scale-up manner. + Default: False. + """ + + def __init__( + self, + in_channels, + conv_layer=None, + act_layer=nn.ReLU, + channels=96, + loss_decode=(), + sampler=None, + align_corners=False, + min_depth=1e-3, + max_depth=None, + norm_layer=None, + classify=False, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + scale_up=False, + ): + super(DepthBaseDecodeHead, self).__init__() + + self.in_channels = in_channels + self.channels = channels + self.conf_layer = conv_layer + self.act_layer = act_layer + self.loss_decode = loss_decode + self.align_corners = align_corners + self.min_depth = min_depth + self.max_depth = max_depth + self.norm_layer = norm_layer + self.classify = classify + self.n_bins = n_bins + self.scale_up = scale_up + + if self.classify: + assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID" + assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid" + + self.bins_strategy = bins_strategy + self.norm_strategy = norm_strategy + self.softmax = nn.Softmax(dim=1) + self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1) + else: + self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1) + + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs, img_metas): + """Placeholder of forward function.""" + pass + + def forward_train(self, img, inputs, img_metas, depth_gt): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): GT depth + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + depth_pred = self.forward(inputs, img_metas) + losses = self.losses(depth_pred, depth_gt) + + log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0]) + losses.update(**log_imgs) + + return losses + + def forward_test(self, inputs, img_metas): + """Forward function for testing. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + + Returns: + Tensor: Output depth map. + """ + return self.forward(inputs, img_metas) + + def depth_pred(self, feat): + """Prediction each pixel.""" + if self.classify: + logit = self.conv_depth(feat) + + if self.bins_strategy == "UD": + bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + elif self.bins_strategy == "SID": + bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device) + + # following Adabins, default linear + if self.norm_strategy == "linear": + logit = torch.relu(logit) + eps = 0.1 + logit = logit + eps + logit = logit / logit.sum(dim=1, keepdim=True) + elif self.norm_strategy == "softmax": + logit = torch.softmax(logit, dim=1) + elif self.norm_strategy == "sigmoid": + logit = torch.sigmoid(logit) + logit = logit / logit.sum(dim=1, keepdim=True) + + output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1) + + else: + if self.scale_up: + output = self.sigmoid(self.conv_depth(feat)) * self.max_depth + else: + output = self.relu(self.conv_depth(feat)) + self.min_depth + return output + + def losses(self, depth_pred, depth_gt): + """Compute depth loss.""" + loss = dict() + depth_pred = resize( + input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False + ) + if not isinstance(self.loss_decode, nn.ModuleList): + losses_decode = [self.loss_decode] + else: + losses_decode = self.loss_decode + for loss_decode in losses_decode: + if loss_decode.loss_name not in loss: + loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt) + else: + loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt) + return loss + + def log_images(self, img_path, depth_pred, depth_gt, img_meta): + import numpy as np + + show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0)) + show_img = show_img.numpy().astype(np.float32) + show_img = _imdenormalize( + show_img, + img_meta["img_norm_cfg"]["mean"], + img_meta["img_norm_cfg"]["std"], + img_meta["img_norm_cfg"]["to_rgb"], + ) + show_img = np.clip(show_img, 0, 255) + show_img = show_img.astype(np.uint8) + show_img = show_img[:, :, ::-1] + show_img = show_img.transpose(0, 2, 1) + show_img = show_img.transpose(1, 0, 2) + + depth_pred = depth_pred / torch.max(depth_pred) + depth_gt = depth_gt / torch.max(depth_gt) + + depth_pred_color = copy.deepcopy(depth_pred.detach().cpu()) + depth_gt_color = copy.deepcopy(depth_gt.detach().cpu()) + + return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color} + + +class BNHead(DepthBaseDecodeHead): + """Just a batchnorm.""" + + def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs): + super().__init__(**kwargs) + self.input_transform = input_transform + self.in_index = in_index + self.upsample = upsample + # self.bn = nn.SyncBatchNorm(self.in_channels) + if self.classify: + self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1) + else: + self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1) + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + Tensor: The transformed inputs + """ + + if "concat" in self.input_transform: + inputs = [inputs[i] for i in self.in_index] + if "resize" in self.input_transform: + inputs = [ + resize( + input=x, + size=[s * self.upsample for s in inputs[0].shape[2:]], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def _forward_feature(self, inputs, img_metas=None, **kwargs): + """Forward function for feature maps before classifying each pixel with + ``self.cls_seg`` fc. + Args: + inputs (list[Tensor]): List of multi-level img features. + Returns: + feats (Tensor): A tensor of shape (batch_size, self.channels, + H, W) which is feature map for last layer of decoder head. + """ + # accept lists (for cls token) + inputs = list(inputs) + for i, x in enumerate(inputs): + if len(x) == 2: + x, cls_token = x[0], x[1] + if len(x.shape) == 2: + x = x[:, :, None, None] + cls_token = cls_token[:, :, None, None].expand_as(x) + inputs[i] = torch.cat((x, cls_token), 1) + else: + x = x[0] + if len(x.shape) == 2: + x = x[:, :, None, None] + inputs[i] = x + x = self._transform_inputs(inputs) + # feats = self.bn(x) + return x + + def forward(self, inputs, img_metas=None, **kwargs): + """Forward function.""" + output = self._forward_feature(inputs, img_metas=img_metas, **kwargs) + output = self.depth_pred(output) + return output + + +class ConvModule(nn.Module): + """A conv block that bundles conv/norm/activation layers. + + This block simplifies the usage of convolution layers, which are commonly + used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU). + It is based upon three build methods: `build_conv_layer()`, + `build_norm_layer()` and `build_activation_layer()`. + + Besides, we add some additional features in this module. + 1. Automatically set `bias` of the conv layer. + 2. Spectral norm is supported. + 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only + supports zero and circular padding, and we add "reflect" padding mode. + + Args: + in_channels (int): Number of channels in the input feature map. + Same as that in ``nn._ConvNd``. + out_channels (int): Number of channels produced by the convolution. + Same as that in ``nn._ConvNd``. + kernel_size (int | tuple[int]): Size of the convolving kernel. + Same as that in ``nn._ConvNd``. + stride (int | tuple[int]): Stride of the convolution. + Same as that in ``nn._ConvNd``. + padding (int | tuple[int]): Zero-padding added to both sides of + the input. Same as that in ``nn._ConvNd``. + dilation (int | tuple[int]): Spacing between kernel elements. + Same as that in ``nn._ConvNd``. + groups (int): Number of blocked connections from input channels to + output channels. Same as that in ``nn._ConvNd``. + bias (bool | str): If specified as `auto`, it will be decided by the + norm_layer. Bias will be set as True if `norm_layer` is None, otherwise + False. Default: "auto". + conv_layer (nn.Module): Convolution layer. Default: None, + which means using conv2d. + norm_layer (nn.Module): Normalization layer. Default: None. + act_layer (nn.Module): Activation layer. Default: nn.ReLU. + inplace (bool): Whether to use inplace mode for activation. + Default: True. + with_spectral_norm (bool): Whether use spectral norm in conv module. + Default: False. + padding_mode (str): If the `padding_mode` has not been supported by + current `Conv2d` in PyTorch, we will use our own padding layer + instead. Currently, we support ['zeros', 'circular'] with official + implementation and ['reflect'] with our own implementation. + Default: 'zeros'. + order (tuple[str]): The order of conv/norm/activation layers. It is a + sequence of "conv", "norm" and "act". Common examples are + ("conv", "norm", "act") and ("act", "conv", "norm"). + Default: ('conv', 'norm', 'act'). + """ + + _abbr_ = "conv_block" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias="auto", + conv_layer=nn.Conv2d, + norm_layer=None, + act_layer=nn.ReLU, + inplace=True, + with_spectral_norm=False, + padding_mode="zeros", + order=("conv", "norm", "act"), + ): + super(ConvModule, self).__init__() + official_padding_mode = ["zeros", "circular"] + self.conv_layer = conv_layer + self.norm_layer = norm_layer + self.act_layer = act_layer + self.inplace = inplace + self.with_spectral_norm = with_spectral_norm + self.with_explicit_padding = padding_mode not in official_padding_mode + self.order = order + assert isinstance(self.order, tuple) and len(self.order) == 3 + assert set(order) == set(["conv", "norm", "act"]) + + self.with_norm = norm_layer is not None + self.with_activation = act_layer is not None + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == "auto": + bias = not self.with_norm + self.with_bias = bias + + if self.with_explicit_padding: + if padding_mode == "zeros": + padding_layer = nn.ZeroPad2d + else: + raise AssertionError(f"Unsupported padding mode: {padding_mode}") + self.pad = padding_layer(padding) + + # reset padding to 0 for conv module + conv_padding = 0 if self.with_explicit_padding else padding + # build convolution layer + self.conv = self.conv_layer( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=conv_padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + # export the attributes of self.conv to a higher level for convenience + self.in_channels = self.conv.in_channels + self.out_channels = self.conv.out_channels + self.kernel_size = self.conv.kernel_size + self.stride = self.conv.stride + self.padding = padding + self.dilation = self.conv.dilation + self.transposed = self.conv.transposed + self.output_padding = self.conv.output_padding + self.groups = self.conv.groups + + if self.with_spectral_norm: + self.conv = nn.utils.spectral_norm(self.conv) + + # build normalization layers + if self.with_norm: + # norm layer is after conv layer + if order.index("norm") > order.index("conv"): + norm_channels = out_channels + else: + norm_channels = in_channels + norm = partial(norm_layer, num_features=norm_channels) + self.add_module("norm", norm) + if self.with_bias: + from torch.nnModules.batchnorm import _BatchNorm + from torch.nnModules.instancenorm import _InstanceNorm + + if isinstance(norm, (_BatchNorm, _InstanceNorm)): + warnings.warn("Unnecessary conv bias before batch/instance norm") + else: + self.norm_name = None + + # build activation layer + if self.with_activation: + # nn.Tanh has no 'inplace' argument + # (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.HSigmoid, nn.Swish, nn.GELU) + if not isinstance(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)): + act_layer = partial(act_layer, inplace=inplace) + self.activate = act_layer() + + # Use msra init by default + self.init_weights() + + @property + def norm(self): + if self.norm_name: + return getattr(self, self.norm_name) + else: + return None + + def init_weights(self): + # 1. It is mainly for customized conv layers with their own + # initialization manners by calling their own ``init_weights()``, + # and we do not want ConvModule to override the initialization. + # 2. For customized conv layers without their own initialization + # manners (that is, they don't have their own ``init_weights()``) + # and PyTorch's conv layers, they will be initialized by + # this method with default ``kaiming_init``. + # Note: For PyTorch's conv layers, they will be overwritten by our + # initialization implementation using default ``kaiming_init``. + if not hasattr(self.conv, "init_weights"): + if self.with_activation and isinstance(self.act_layer, nn.LeakyReLU): + nonlinearity = "leaky_relu" + a = 0.01 # XXX: default negative_slope + else: + nonlinearity = "relu" + a = 0 + if hasattr(self.conv, "weight") and self.conv.weight is not None: + nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity) + if hasattr(self.conv, "bias") and self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + if self.with_norm: + if hasattr(self.norm, "weight") and self.norm.weight is not None: + nn.init.constant_(self.norm.weight, 1) + if hasattr(self.norm, "bias") and self.norm.bias is not None: + nn.init.constant_(self.norm.bias, 0) + + def forward(self, x, activate=True, norm=True): + for layer in self.order: + if layer == "conv": + if self.with_explicit_padding: + x = self.pad(x) + x = self.conv(x) + elif layer == "norm" and norm and self.with_norm: + x = self.norm(x) + elif layer == "act" and activate and self.with_activation: + x = self.activate(x) + return x + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + return x + + +class HeadDepth(nn.Module): + def __init__(self, features): + super(HeadDepth, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, x): + x = self.head(x) + return x + + +class ReassembleBlocks(nn.Module): + """ViTPostProcessBlock, process cls_token in ViT backbone output and + rearrange the feature vector to feature map. + Args: + in_channels (int): ViT feature channels. Default: 768. + out_channels (List): output channels of each stage. + Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + """ + + def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16): + super(ReassembleBlocks, self).__init__() + + assert readout_type in ["ignore", "add", "project"] + self.readout_type = readout_type + self.patch_size = patch_size + + self.projects = nn.ModuleList( + [ + ConvModule( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + act_layer=None, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + if self.readout_type == "project": + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())) + + def forward(self, inputs): + assert isinstance(inputs, list) + out = [] + for i, x in enumerate(inputs): + assert len(x) == 2 + x, cls_token = x[0], x[1] + feature_shape = x.shape + if self.readout_type == "project": + x = x.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + x = x.permute(0, 2, 1).reshape(feature_shape) + elif self.readout_type == "add": + x = x.flatten(2) + cls_token.unsqueeze(-1) + x = x.reshape(feature_shape) + else: + pass + x = self.projects[i](x) + x = self.resize_layers[i](x) + out.append(x) + return out + + +class PreActResidualConvUnit(nn.Module): + """ResidualConvUnit, pre-activate residual unit. + Args: + in_channels (int): number of channels in the input feature map. + act_layer (nn.Module): activation layer. + norm_layer (nn.Module): norm layer. + stride (int): stride of the first block. Default: 1 + dilation (int): dilation rate for convs layers. Default: 1. + """ + + def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1): + super(PreActResidualConvUnit, self).__init__() + + self.conv1 = ConvModule( + in_channels, + in_channels, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + self.conv2 = ConvModule( + in_channels, + in_channels, + 3, + padding=1, + norm_layer=norm_layer, + act_layer=act_layer, + bias=False, + order=("act", "conv", "norm"), + ) + + def forward(self, inputs): + inputs_ = inputs.clone() + x = self.conv1(inputs) + x = self.conv2(x) + return x + inputs_ + + +class FeatureFusionBlock(nn.Module): + """FeatureFusionBlock, merge feature map from different stages. + Args: + in_channels (int): Input channels. + act_layer (nn.Module): activation layer for ResidualConvUnit. + norm_layer (nn.Module): normalization layer. + expand (bool): Whether expand the channels in post process block. + Default: False. + align_corners (bool): align_corner setting for bilinear upsample. + Default: True. + """ + + def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True): + super(FeatureFusionBlock, self).__init__() + + self.in_channels = in_channels + self.expand = expand + self.align_corners = align_corners + + self.out_channels = in_channels + if self.expand: + self.out_channels = in_channels // 2 + + self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True) + + self.res_conv_unit1 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + self.res_conv_unit2 = PreActResidualConvUnit( + in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer + ) + + def forward(self, *inputs): + x = inputs[0] + if len(inputs) == 2: + if x.shape != inputs[1].shape: + res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False) + else: + res = inputs[1] + x = x + self.res_conv_unit1(res) + x = self.res_conv_unit2(x) + x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners) + x = self.project(x) + return x + + +class DPTHead(DepthBaseDecodeHead): + """Vision Transformers for Dense Prediction. + This head is implemented of `DPT `_. + Args: + embed_dims (int): The embed dimension of the ViT backbone. + Default: 768. + post_process_channels (List): Out channels of post process conv + layers. Default: [96, 192, 384, 768]. + readout_type (str): Type of readout operation. Default: 'ignore'. + patch_size (int): The patch size. Default: 16. + expand_channels (bool): Whether expand the channels in post process + block. Default: False. + """ + + def __init__( + self, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + readout_type="ignore", + patch_size=16, + expand_channels=False, + **kwargs, + ): + super(DPTHead, self).__init__(**kwargs) + + self.in_channels = self.in_channels + self.expand_channels = expand_channels + self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size) + + self.post_process_channels = [ + channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels) + ] + self.convs = nn.ModuleList() + for channel in self.post_process_channels: + self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False)) + self.fusion_blocks = nn.ModuleList() + for _ in range(len(self.convs)): + self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer)) + self.fusion_blocks[0].res_conv_unit1 = None + self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer) + self.num_fusion_blocks = len(self.fusion_blocks) + self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers) + self.num_post_process_channels = len(self.post_process_channels) + assert self.num_fusion_blocks == self.num_reassemble_blocks + assert self.num_reassemble_blocks == self.num_post_process_channels + self.conv_depth = HeadDepth(self.channels) + + def forward(self, inputs, img_metas): + assert len(inputs) == self.num_reassemble_blocks + x = [inp for inp in inputs] + x = self.reassemble_blocks(x) + x = [self.convs[i](feature) for i, feature in enumerate(x)] + out = self.fusion_blocks[0](x[-1]) + for i in range(1, len(self.fusion_blocks)): + out = self.fusion_blocks[i](out, x[-(i + 1)]) + out = self.project(out) + out = self.depth_pred(out) + return out diff --git a/src/dinov2/hub/depth/encoder_decoder.py b/src/dinov2/hub/depth/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb29ced67957a336e763b0e7c90c0eeaea36fea8 --- /dev/null +++ b/src/dinov2/hub/depth/encoder_decoder.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ops import resize + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs + + +class DepthEncoderDecoder(nn.Module): + """Encoder Decoder depther. + + EncoderDecoder typically consists of backbone and decode_head. + """ + + def __init__(self, backbone, decode_head): + super(DepthEncoderDecoder, self).__init__() + + self.backbone = backbone + self.decode_head = decode_head + self.align_corners = self.decode_head.align_corners + + def extract_feat(self, img): + """Extract features from images.""" + return self.backbone(img) + + def encode_decode(self, img, img_metas, rescale=True, size=None): + """Encode images with backbone and decode into a depth estimation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + # crop the pred depth to the certain range. + out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth) + if rescale: + if size is None: + if img_metas is not None: + size = img_metas[0]["ori_shape"][:2] + else: + size = img.shape[2:] + out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs) + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + depth_pred = self.decode_head.forward_test(x, img_metas) + return depth_pred + + def forward_dummy(self, img): + """Dummy forward function.""" + depth = self.encode_decode(img, None) + + return depth + + def forward_train(self, img, img_metas, depth_gt, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + depth_gt (Tensor): Depth gt + used if the architecture supports depth estimation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + # the last of x saves the info from neck + loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs) + + losses.update(loss_decode) + + return losses + + def whole_inference(self, img, img_meta, rescale, size=None): + """Inference with full image.""" + return self.encode_decode(img, img_meta, rescale, size=size) + + def slide_inference(self, img, img_meta, rescale, stride, crop_size): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, _, h_img, w_img = img.size() + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, 1, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + depth_pred = self.encode_decode(crop_img, img_meta, rescale) + preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + return preds + + def inference(self, img, img_meta, rescale, size=None, mode="whole"): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `depth/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output depth map. + """ + + assert mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if mode == "slide": + depth_pred = self.slide_inference(img, img_meta, rescale) + else: + depth_pred = self.whole_inference(img, img_meta, rescale, size=size) + output = depth_pred + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + depth_pred = self.inference(img, img_meta, rescale) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + depth_pred = depth_pred.unsqueeze(0) + return depth_pred + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented depth logit inplace + depth_pred = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:]) + depth_pred += cur_depth_pred + depth_pred /= len(imgs) + depth_pred = depth_pred.cpu().numpy() + # unravel batch dim + depth_pred = list(depth_pred) + return depth_pred + + def forward_test(self, imgs, img_metas, **kwargs): + """ + Args: + imgs (List[Tensor]): the outer list indicates test-time + augmentations and inner Tensor should have a shape NxCxHxW, + which contains all images in the batch. + img_metas (List[List[dict]]): the outer list indicates test-time + augs (multiscale, flip, etc.) and the inner list indicates + images in a batch. + """ + for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]: + if not isinstance(var, list): + raise TypeError(f"{name} must be a list, but got " f"{type(var)}") + num_augs = len(imgs) + if num_augs != len(img_metas): + raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})") + # all images in the same aug batch all of the same ori_shape and pad + # shape + for img_meta in img_metas: + ori_shapes = [_["ori_shape"] for _ in img_meta] + assert all(shape == ori_shapes[0] for shape in ori_shapes) + img_shapes = [_["img_shape"] for _ in img_meta] + assert all(shape == img_shapes[0] for shape in img_shapes) + pad_shapes = [_["pad_shape"] for _ in img_meta] + assert all(shape == pad_shapes[0] for shape in pad_shapes) + + if num_augs == 1: + return self.simple_test(imgs[0], img_metas[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + + def forward(self, img, img_metas, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: + return self.forward_test(img, img_metas, **kwargs) + + def train_step(self, data_batch, optimizer, **kwargs): + """The iteration step during training. + + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, + ``num_samples``. + ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + ``log_vars`` contains all the variables to be sent to the + logger. + ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data_batch) + + # split losses and images + real_losses = {} + log_imgs = {} + for k, v in losses.items(): + if "img" in k: + log_imgs[k] = v + else: + real_losses[k] = v + + loss, log_vars = self._parse_losses(real_losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs) + + return outputs + + def val_step(self, data_batch, **kwargs): + """The iteration step during validation. + + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. + """ + output = self(**data_batch, **kwargs) + return output + + @staticmethod + def _parse_losses(losses): + import torch.distributed as dist + + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor + which may be a weighted sum of all losses, log_vars contains + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key) + + log_vars["loss"] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars diff --git a/src/dinov2/hub/depth/ops.py b/src/dinov2/hub/depth/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e --- /dev/null +++ b/src/dinov2/hub/depth/ops.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch.nn.functional as F + + +def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/src/dinov2/hub/depthers.py b/src/dinov2/hub/depthers.py new file mode 100644 index 0000000000000000000000000000000000000000..f88b7e9a41056594e3b3e66107feee98bffab820 --- /dev/null +++ b/src/dinov2/hub/depthers.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from functools import partial +from typing import Optional, Tuple, Union + +import torch + +from .backbones import _make_dinov2_model +from .depth import BNHead, DepthEncoderDecoder, DPTHead +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding + + +class Weights(Enum): + NYU = "NYU" + KITTI = "KITTI" + + +def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]: + if not pretrained: # Default + return (0.001, 10.0) + + # Pretrained, set according to the training dataset for the provided weights + if weights == Weights.KITTI: + return (0.001, 80.0) + + if weights == Weights.NYU: + return (0.001, 10.0) + + return (0.001, 10.0) + + +def _make_dinov2_linear_depth_head( + *, + embed_dim: int, + layers: int, + min_depth: float, + max_depth: float, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + + if layers == 1: + in_index = [0] + else: + assert layers == 4 + in_index = [0, 1, 2, 3] + + return BNHead( + classify=True, + n_bins=256, + bins_strategy="UD", + norm_strategy="linear", + upsample=4, + in_channels=[embed_dim] * len(in_index), + in_index=in_index, + input_transform="resize_concat", + channels=embed_dim * len(in_index) * 2, + align_corners=False, + min_depth=0.001, + max_depth=80, + loss_decode=(), + ) + + +def _make_dinov2_linear_depther( + *, + arch_name: str = "vit_large", + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if layers not in (1, 4): + raise AssertionError(f"Unsupported number of layers: {layers}") + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + embed_dim = backbone.embed_dim + patch_size = backbone.patch_size + model_name = _make_dinov2_model_name(arch_name, patch_size) + linear_depth_head = _make_dinov2_linear_depth_head( + embed_dim=embed_dim, + layers=layers, + min_depth=min_depth, + max_depth=max_depth, + ) + + layer_count = { + "vit_small": 12, + "vit_base": 12, + "vit_large": 24, + "vit_giant2": 40, + }[arch_name] + + if layers == 4: + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + else: + assert layers == 1 + out_index = [layer_count - 1] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0])) + + if pretrained: + layers_str = str(layers) if layers == 4 else "" + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + ) + + +def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_linear_depther( + arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) + + +def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float): + return DPTHead( + in_channels=[embed_dim] * 4, + channels=256, + embed_dims=embed_dim, + post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)], + readout_type="project", + min_depth=min_depth, + max_depth=max_depth, + loss_decode=(), + ) + + +def _make_dinov2_dpt_depther( + *, + arch_name: str = "vit_large", + pretrained: bool = True, + weights: Union[Weights, str] = Weights.NYU, + depth_range: Optional[Tuple[float, float]] = None, + **kwargs, +): + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + if depth_range is None: + depth_range = _get_depth_range(pretrained, weights) + min_depth, max_depth = depth_range + + backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + + model_name = _make_dinov2_model_name(arch_name, backbone.patch_size) + dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth) + + out_index = { + "vit_small": [2, 5, 8, 11], + "vit_base": [2, 5, 8, 11], + "vit_large": [4, 11, 17, 23], + "vit_giant2": [9, 19, 29, 39], + }[arch_name] + + model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head) + model.backbone.forward = partial( + backbone.get_intermediate_layers, + n=out_index, + reshape=True, + return_class_token=True, + norm=False, + ) + model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0])) + + if pretrained: + weights_str = weights.value.lower() + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth" + checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + model.load_state_dict(state_dict, strict=False) + + return model + + +def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs): + return _make_dinov2_dpt_depther( + arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + ) diff --git a/src/dinov2/hub/utils.py b/src/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/src/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/src/dinov2/layers/__init__.py b/src/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3317a482f21ee3e926958364d24ab2185cdf07da --- /dev/null +++ b/src/dinov2/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .block_vis import NestedTensorBlock as NestedTensorBlockVis +from .attention import MemEffAttention diff --git a/src/dinov2/layers/__pycache__/__init__.cpython-310.pyc b/src/dinov2/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa5e1b01960d63986c34b6fe11ea206a45568d20 Binary files /dev/null and b/src/dinov2/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/attention.cpython-310.pyc b/src/dinov2/layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..961b432480e0fdc73bcbe7b8916bf1005d4e8e3d Binary files /dev/null and b/src/dinov2/layers/__pycache__/attention.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/block.cpython-310.pyc b/src/dinov2/layers/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa6197a3a7082383200a731a8c3358ae1cb5f1ef Binary files /dev/null and b/src/dinov2/layers/__pycache__/block.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/block_vis.cpython-310.pyc b/src/dinov2/layers/__pycache__/block_vis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce75e5291a276103e9aa81b7533d584c9f403961 Binary files /dev/null and b/src/dinov2/layers/__pycache__/block_vis.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/dino_head.cpython-310.pyc b/src/dinov2/layers/__pycache__/dino_head.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dcd10e0b15bb54dff6bcd00b0f4d64631a7a963 Binary files /dev/null and b/src/dinov2/layers/__pycache__/dino_head.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/drop_path.cpython-310.pyc b/src/dinov2/layers/__pycache__/drop_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef09d4ea647d5ed86f7183219dfad8f4813bfd7b Binary files /dev/null and b/src/dinov2/layers/__pycache__/drop_path.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc b/src/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1719e77966795f907549df4fa1f3f73f93934fde Binary files /dev/null and b/src/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/mlp.cpython-310.pyc b/src/dinov2/layers/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe6952b2aae2bad634d8357224c9b9768f05f672 Binary files /dev/null and b/src/dinov2/layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc b/src/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1037af46c5d11036c74e3234232e601e85565dcd Binary files /dev/null and b/src/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/src/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/src/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..502838456d29b5b3242ac800cf97c7e449167d32 Binary files /dev/null and b/src/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc differ diff --git a/src/dinov2/layers/attention.py b/src/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ece26a0d24c29359fcb1c3bc78aa731ff10adf --- /dev/null +++ b/src/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not False: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/src/dinov2/layers/block.py b/src/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff4fded00243b70c6690bb367026013e9e30b12 --- /dev/null +++ b/src/dinov2/layers/block.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward_2(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + assert isinstance(x_or_x_list, torch.Tensor), "Expected a torch.Tensor" + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/src/dinov2/layers/block_vis.py b/src/dinov2/layers/block_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..61d510496992549439a2193a77ec28064da5d2a2 --- /dev/null +++ b/src/dinov2/layers/block_vis.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, return_attention=False) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + # Add this 2 lines + if return_attention: + return self.attn(self.norm1(x), return_attn=True) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, return_attention=False): + if isinstance(x_or_x_list, Tensor): + # Change the following line + # return super().forward(x_or_x_list) + return super().forward(x_or_x_list, return_attention) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + diff --git a/src/dinov2/layers/dino_head.py b/src/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/src/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/src/dinov2/layers/drop_path.py b/src/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/src/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/src/dinov2/layers/layer_scale.py b/src/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/src/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/src/dinov2/layers/mlp.py b/src/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/src/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/src/dinov2/layers/patch_embed.py b/src/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/src/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/src/dinov2/layers/swiglu_ffn.py b/src/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74 --- /dev/null +++ b/src/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/src/dinov2/logging/__init__.py b/src/dinov2/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04a7f02204316d4d1ef38bf6080dae3d66241c25 --- /dev/null +++ b/src/dinov2/logging/__init__.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import functools +import logging +import os +import sys +from typing import Optional + +import dinov2.distributed as distributed +from .helpers import MetricLogger, SmoothedValue + + +# So that calling _configure_logger multiple times won't add many handlers +@functools.lru_cache() +def _configure_logger( + name: Optional[str] = None, + *, + level: int = logging.DEBUG, + output: Optional[str] = None, +): + """ + Configure a logger. + + Adapted from Detectron2. + + Args: + name: The name of the logger to configure. + level: The logging level to use. + output: A file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + + Returns: + The configured logger. + """ + + logger = logging.getLogger(name) + logger.setLevel(level) + logger.propagate = False + + # Loosely match Google glog format: + # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg + # but use a shorter timestamp and include the logger name: + # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg + fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " + fmt_message = "%(message)s" + fmt = fmt_prefix + fmt_message + datefmt = "%Y%m%d %H:%M:%S" + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + + # stdout logging for main worker only + if distributed.is_main_process(): + handler = logging.StreamHandler(stream=sys.stdout) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # file logging for all workers + if output: + if os.path.splitext(output)[-1] in (".txt", ".log"): + filename = output + else: + filename = os.path.join(output, "logs", "log.txt") + + if not distributed.is_main_process(): + global_rank = distributed.get_global_rank() + filename = filename + ".rank{}".format(global_rank) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + + handler = logging.StreamHandler(open(filename, "a")) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger + + +def setup_logging( + output: Optional[str] = None, + *, + name: Optional[str] = None, + level: int = logging.DEBUG, + capture_warnings: bool = True, +) -> None: + """ + Setup logging. + + Args: + output: A file name or a directory to save log files. If None, log + files will not be saved. If output ends with ".txt" or ".log", it + is assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. + capture_warnings: Whether warnings should be captured as logs. + """ + logging.captureWarnings(capture_warnings) + _configure_logger(name, level=level, output=output) diff --git a/src/dinov2/logging/helpers.py b/src/dinov2/logging/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..c6e70bb15505cbbc4c4732b069ee919bf921a74f --- /dev/null +++ b/src/dinov2/logging/helpers.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict, deque +import datetime +import json +import logging +import time + +import torch + +import dinov2.distributed as distributed + + +logger = logging.getLogger("dinov2") + + +class MetricLogger(object): + def __init__(self, delimiter="\t", output_file=None): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.output_file = output_file + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def dump_in_output_file(self, iteration, iter_time, data_time): + if self.output_file is None or not distributed.is_main_process(): + return + dict_to_dump = dict( + iteration=iteration, + iter_time=iter_time, + data_time=data_time, + ) + dict_to_dump.update({k: v.median for k, v in self.meters.items()}) + with open(self.output_file, "a") as f: + f.write(json.dumps(dict_to_dump) + "\n") + pass + + def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0): + i = start_iteration + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.6f}") + data_time = SmoothedValue(fmt="{avg:.6f}") + + if n_iterations is None: + n_iterations = len(iterable) + + space_fmt = ":" + str(len(str(n_iterations))) + "d" + + log_list = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_list += ["max mem: {memory:.0f}"] + + log_msg = self.delimiter.join(log_list) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == n_iterations - 1: + self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) + eta_seconds = iter_time.global_avg * (n_iterations - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + if i >= n_iterations: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations)) + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, num=1): + self.deque.append(value) + self.count += num + self.total += value * num + + def synchronize_between_processes(self): + """ + Distributed synchronization of the metric + Warning: does not synchronize the deque! + """ + if not distributed.is_enabled(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + torch.distributed.barrier() + torch.distributed.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) diff --git a/src/dinov2/loss/__init__.py b/src/dinov2/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b0115b74edbd74b324c9056a57fade363c58fd --- /dev/null +++ b/src/dinov2/loss/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_clstoken_loss import DINOLoss +from .ibot_patch_loss import iBOTPatchLoss +from .koleo_loss import KoLeoLoss diff --git a/src/dinov2/loss/dino_clstoken_loss.py b/src/dinov2/loss/dino_clstoken_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c31808e36e6c38ee6dae13ba0443bf1946242117 --- /dev/null +++ b/src/dinov2/loss/dino_clstoken_loss.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + + +class DINOLoss(nn.Module): + def __init__( + self, + out_dim, + student_temp=0.1, + center_momentum=0.9, + ): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_output = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_output, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): + teacher_output = teacher_output.float() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_output_list, teacher_out_softmaxed_centered_list): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + # TODO: Use cross_entropy_distribution here + total_loss = 0 + for s in student_output_list: + lsm = F.log_softmax(s / self.student_temp, dim=-1) + for t in teacher_out_softmaxed_centered_list: + loss = torch.sum(t * lsm, dim=-1) + total_loss -= loss.mean() + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + self.reduce_center_update(teacher_output) + + @torch.no_grad() + def reduce_center_update(self, teacher_output): + self.updated = False + self.len_teacher_output = len(teacher_output) + self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_output * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/src/dinov2/loss/ibot_patch_loss.py b/src/dinov2/loss/ibot_patch_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..6732cda0c311c69f193669ebc950fc8665871442 --- /dev/null +++ b/src/dinov2/loss/ibot_patch_loss.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +import logging + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import cross_entropy + + def lossfunc(t, s, temp): + s = s.float() + t = t.float() + if s.ndim == 2: + return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0) + elif s.ndim == 3: + return -cross_entropy(s, t, temp, bw_inplace=True) + +except ImportError: + + def lossfunc(t, s, temp): + return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) + + +class iBOTPatchLoss(nn.Module): + def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_patch_tokens = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + # + # WARNING: + # as self.center is a float32, everything gets casted to float32 afterwards + # + # teacher_patch_tokens = teacher_patch_tokens.float() + # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1) + + return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) + + # this is experimental, keep everything in float16 and let's see what happens: + # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3): + teacher_output = teacher_output.float() + # world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + # B = Q.shape[1] * world_size # number of samples to assign + B = n_masked_patches_tensor + dist.all_reduce(B) + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + student_patch_tokens: (B, N, D) tensor + teacher_patch_tokens: (B, N, D) tensor + student_masks_flat: (B, N) tensor + """ + t = teacher_patch_tokens + s = student_patch_tokens + loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) + return -loss.mean() + + def forward_masked( + self, + student_patch_tokens_masked, + teacher_patch_tokens_masked, + student_masks_flat, + n_masked_patches=None, + masks_weight=None, + ): + t = teacher_patch_tokens_masked + s = student_patch_tokens_masked + # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = lossfunc(t, s, self.student_temp) + if masks_weight is None: + masks_weight = ( + (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + .unsqueeze(-1) + .expand_as(student_masks_flat)[student_masks_flat] + ) + if n_masked_patches is not None: + loss = loss[:n_masked_patches] + loss = loss * masks_weight + return -loss.sum() / student_masks_flat.shape[0] + + @torch.no_grad() + def update_center(self, teacher_patch_tokens): + self.reduce_center_update(teacher_patch_tokens) + + @torch.no_grad() + def reduce_center_update(self, teacher_patch_tokens): + self.updated = False + self.len_teacher_patch_tokens = len(teacher_patch_tokens) + self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/src/dinov2/loss/koleo_loss.py b/src/dinov2/loss/koleo_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b5cbcd91e0fc0b857f477b0910f957f02a6c4335 --- /dev/null +++ b/src/dinov2/loss/koleo_loss.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# import torch.distributed as dist + + +logger = logging.getLogger("dinov2") + + +class KoLeoLoss(nn.Module): + """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" + + def __init__(self): + super().__init__() + self.pdist = nn.PairwiseDistance(2, eps=1e-8) + + def pairwise_NNs_inner(self, x): + """ + Pairwise nearest neighbors for L2-normalized vectors. + Uses Torch rather than Faiss to remain on GPU. + """ + # parwise dot products (= inverse distance) + dots = torch.mm(x, x.t()) + n = x.shape[0] + dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 + # max inner prod -> min distance + _, I = torch.max(dots, dim=1) # noqa: E741 + return I + + def forward(self, student_output, eps=1e-8): + """ + Args: + student_output (BxD): backbone output of student + """ + with torch.cuda.amp.autocast(enabled=False): + student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) + I = self.pairwise_NNs_inner(student_output) # noqa: E741 + distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B + loss = -torch.log(distances + eps).mean() + return loss diff --git a/src/dinov2/models/__init__.py b/src/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/src/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/src/dinov2/models/__pycache__/__init__.cpython-310.pyc b/src/dinov2/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4eae53201342ecaffb878e843d6aeee20104b0b Binary files /dev/null and b/src/dinov2/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/dinov2/models/__pycache__/vision_transformer.cpython-310.pyc b/src/dinov2/models/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec138eb4cfe334e1d69663f24bf51373eeb82f8f Binary files /dev/null and b/src/dinov2/models/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/src/dinov2/models/__pycache__/vision_transformer_vis.cpython-310.pyc b/src/dinov2/models/__pycache__/vision_transformer_vis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d996dc8399707990a3654ed3bd143e9fd90440c Binary files /dev/null and b/src/dinov2/models/__pycache__/vision_transformer_vis.cpython-310.pyc differ diff --git a/src/dinov2/models/vision_transformer.py b/src/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d737be71dd3c58f0fea136a7c1c104e4254699c4 --- /dev/null +++ b/src/dinov2/models/vision_transformer.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from src.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=518, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None, prompt=None): + # print("prompt", prompt) + B, nc, w, h = x.shape + x = self.patch_embed(x) + + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + x = x + self.interpolate_pos_encoding(x, w, h) + + # if prompt is not None: + # x = torch.cat([x, prompt], dim=1) + + if prompt is not None: + x = torch.cat( + ( + x[:, :1], + prompt, + x[:, 1:] + ), + dim=1 + ) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list, prompt_list): + x = [self.prepare_tokens_with_masks(x, masks, prompt) for x, masks, prompt in zip(x_list, masks_list, prompt_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None, prompt=None): + # print("forward_features prompt: ", prompt) + if isinstance(x, list): + return self.forward_features_list(x, masks, prompt) + + x = self.prepare_tokens_with_masks(x, masks, prompt) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, x, is_training=False, prompt=None): + ret = self.forward_features(x=x, prompt=prompt) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, init_values=1.0, block_chunks=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + init_values=init_values, # for layerscale: None or 0 => no layerscale + block_chunks=block_chunks, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/src/dinov2/models/vision_transformer_vis.py b/src/dinov2/models/vision_transformer_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..e567dd6103a858fb36042fc5f5d31062fd43825c --- /dev/null +++ b/src/dinov2/models/vision_transformer_vis.py @@ -0,0 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from src.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlockVis as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=518, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def get_last_self_attention(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + # Run through model, at the last block just return the attention. + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + return blk(x, return_attention=True) + + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None, prompt=None): + # print("prompt", prompt) + B, nc, w, h = x.shape + x = self.patch_embed(x) + + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + x = x + self.interpolate_pos_encoding(x, w, h) + + # if prompt is not None: + # x = torch.cat([x, prompt], dim=1) + + if prompt is not None: + x = torch.cat( + ( + x[:, :1], + prompt, + x[:, 1:] + ), + dim=1 + ) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list, prompt_list): + x = [self.prepare_tokens_with_masks(x, masks, prompt) for x, masks, prompt in zip(x_list, masks_list, prompt_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None, prompt=None): + # print("forward_features prompt: ", prompt) + if isinstance(x, list): + return self.forward_features_list(x, masks, prompt) + + x = self.prepare_tokens_with_masks(x, masks, prompt) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, init_values=1.0, block_chunks=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + init_values=init_values, # for layerscale: None or 0 => no layerscale + block_chunks=block_chunks, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/src/dinov2/run/__init__.py b/src/dinov2/run/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/dinov2/run/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/dinov2/run/eval/knn.py b/src/dinov2/run/eval/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..d11918445cdfe415fe58ac8b3ad0bf29702e3457 --- /dev/null +++ b/src/dinov2/run/eval/knn.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.knn import get_args_parser as get_knn_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.knn import main as knn_main + + self._setup_args() + knn_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 k-NN evaluation" + knn_args_parser = get_knn_args_parser(add_help=False) + parents = [knn_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:knn") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/dinov2/run/eval/linear.py b/src/dinov2/run/eval/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..e1dc3293e88512a5cf885ab775dc08e01aed6724 --- /dev/null +++ b/src/dinov2/run/eval/linear.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.linear import get_args_parser as get_linear_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.linear import main as linear_main + + self._setup_args() + linear_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 linear evaluation" + linear_args_parser = get_linear_args_parser(add_help=False) + parents = [linear_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:linear") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/dinov2/run/eval/log_regression.py b/src/dinov2/run/eval/log_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf02181122de72cfa463ef38494967219df9cf3 --- /dev/null +++ b/src/dinov2/run/eval/log_regression.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser +from dinov2.logging import setup_logging +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.eval.log_regression import main as log_regression_main + + self._setup_args() + log_regression_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 logistic evaluation" + log_regression_args_parser = get_log_regression_args_parser(add_help=False) + parents = [log_regression_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:logreg") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/dinov2/run/submit.py b/src/dinov2/run/submit.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1f718e704cf9a48913422404c25a7fcc50e738 --- /dev/null +++ b/src/dinov2/run/submit.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +from pathlib import Path +from typing import List, Optional + +import submitit + +from dinov2.utils.cluster import ( + get_slurm_executor_parameters, + get_slurm_partition, + get_user_checkpoint_path, +) + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +) -> argparse.ArgumentParser: + parents = parents or [] + slurm_partition = get_slurm_partition() + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--ngpus", + "--gpus", + "--gpus-per-node", + default=8, + type=int, + help="Number of GPUs to request on each node", + ) + parser.add_argument( + "--nodes", + "--nnodes", + default=1, + type=int, + help="Number of nodes to request", + ) + parser.add_argument( + "--timeout", + default=2800, + type=int, + help="Duration of the job", + ) + parser.add_argument( + "--partition", + default=slurm_partition, + type=str, + help="Partition where to submit", + ) + parser.add_argument( + "--use-volta32", + action="store_true", + help="Request V100-32GB GPUs", + ) + parser.add_argument( + "--comment", + default="", + type=str, + help="Comment to pass to scheduler, e.g. priority message", + ) + parser.add_argument( + "--exclude", + default="", + type=str, + help="Nodes to exclude", + ) + return parser + + +def get_shared_folder() -> Path: + user_checkpoint_path = get_user_checkpoint_path() + if user_checkpoint_path is None: + raise RuntimeError("Path to user checkpoint cannot be determined") + path = user_checkpoint_path / "experiments" + path.mkdir(exist_ok=True) + return path + + +def submit_jobs(task_class, args, name: str): + if not args.output_dir: + args.output_dir = str(get_shared_folder() / "%j") + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) + + kwargs = {} + if args.use_volta32: + kwargs["slurm_constraint"] = "volta32gb" + if args.comment: + kwargs["slurm_comment"] = args.comment + if args.exclude: + kwargs["slurm_exclude"] = args.exclude + + executor_params = get_slurm_executor_parameters( + nodes=args.nodes, + num_gpus_per_node=args.ngpus, + timeout_min=args.timeout, # max is 60 * 72 + slurm_signal_delay_s=120, + slurm_partition=args.partition, + **kwargs, + ) + executor.update_parameters(name=name, **executor_params) + + task = task_class(args) + job = executor.submit(task) + + logger.info(f"Submitted job_id: {job.job_id}") + str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) + logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") diff --git a/src/dinov2/run/train/train.py b/src/dinov2/run/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c2366e9bf79765e6abcd70dda6b43f31cb7093eb --- /dev/null +++ b/src/dinov2/run/train/train.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from dinov2.logging import setup_logging +from dinov2.train import get_args_parser as get_train_args_parser +from dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + from dinov2.train import main as train_main + + self._setup_args() + train_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 training" + train_args_parser = get_train_args_parser(add_help=False) + parents = [train_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Trainer, args, name="dinov2:train") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/dinov2/train/__init__.py b/src/dinov2/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1752922d04fff0112eb7796be28ff6b68c6073 --- /dev/null +++ b/src/dinov2/train/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .train import get_args_parser, main +from .ssl_meta_arch import SSLMetaArch diff --git a/src/dinov2/train/ssl_meta_arch.py b/src/dinov2/train/ssl_meta_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccf15e904ebeb6134dfb4f5c99da4fc8d41b8e4 --- /dev/null +++ b/src/dinov2/train/ssl_meta_arch.py @@ -0,0 +1,400 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from functools import partial +import logging + +import torch +from torch import nn + +from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss +from dinov2.models import build_model_from_cfg +from dinov2.layers import DINOHead +from dinov2.utils.utils import has_batchnorms +from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups +from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model + +from dinov2.models.vision_transformer import BlockChunk + + +try: + from xformers.ops import fmha +except ImportError: + raise AssertionError("xFormers is required for training") + + +logger = logging.getLogger("dinov2") + + +class SSLMetaArch(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None + + student_model_dict = dict() + teacher_model_dict = dict() + + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + student_model_dict["backbone"] = student_backbone + teacher_model_dict["backbone"] = teacher_backbone + logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") + + if cfg.student.pretrained_weights: + chkpt = torch.load(cfg.student.pretrained_weights) + logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") + student_backbone.load_state_dict(chkpt["model"], strict=False) + + self.embed_dim = embed_dim + self.dino_out_dim = cfg.dino.head_n_prototypes + + self.do_dino = cfg.dino.loss_weight > 0 + self.do_koleo = cfg.dino.koleo_loss_weight > 0 + self.do_ibot = cfg.ibot.loss_weight > 0 + self.ibot_separate_head = cfg.ibot.separate_head + + logger.info("OPTIONS -- DINO") + if self.do_dino: + logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") + logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") + logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") + logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") + self.dino_loss_weight = cfg.dino.loss_weight + dino_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.dino.head_n_prototypes, + hidden_dim=cfg.dino.head_hidden_dim, + bottleneck_dim=cfg.dino.head_bottleneck_dim, + nlayers=cfg.dino.head_nlayers, + ) + self.dino_loss = DINOLoss(self.dino_out_dim) + if self.do_koleo: + logger.info("OPTIONS -- DINO -- applying KOLEO regularization") + self.koleo_loss = KoLeoLoss() + + else: + logger.info("OPTIONS -- DINO -- not using DINO") + + if self.do_dino or self.do_ibot: + student_model_dict["dino_head"] = dino_head() + teacher_model_dict["dino_head"] = dino_head() + + logger.info("OPTIONS -- IBOT") + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") + if self.do_ibot: + self.ibot_loss_weight = cfg.ibot.loss_weight + assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot" + assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" + self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes + self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) + if self.ibot_separate_head: + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") + logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") + logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") + ibot_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.ibot.head_n_prototypes, + hidden_dim=cfg.ibot.head_hidden_dim, + bottleneck_dim=cfg.ibot.head_bottleneck_dim, + nlayers=cfg.ibot.head_nlayers, + ) + student_model_dict["ibot_head"] = ibot_head() + teacher_model_dict["ibot_head"] = ibot_head() + else: + logger.info("OPTIONS -- IBOT -- head shared with DINO") + + self.need_to_synchronize_fsdp_streams = True + + self.student = nn.ModuleDict(student_model_dict) + self.teacher = nn.ModuleDict(teacher_model_dict) + + # there is no backpropagation through the teacher, so no need for gradients + for p in self.teacher.parameters(): + p.requires_grad = False + logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.") + + def forward(self, inputs): + raise NotImplementedError + + def backprop_loss(self, loss): + if self.fp16_scaler is not None: + self.fp16_scaler.scale(loss).backward() + else: + loss.backward() + + def forward_backward(self, images, teacher_temp): + n_global_crops = 2 + assert n_global_crops == 2 + n_local_crops = self.cfg.crops.local_crops_number + + global_crops = images["collated_global_crops"].cuda(non_blocking=True) + local_crops = images["collated_local_crops"].cuda(non_blocking=True) + + masks = images["collated_masks"].cuda(non_blocking=True) + mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) + n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) + n_masked_patches = mask_indices_list.shape[0] + upperbound = images["upperbound"] + masks_weight = images["masks_weight"].cuda(non_blocking=True) + + n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) + n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops + + do_dino = self.do_dino + do_ibot = self.do_ibot + + # loss scales + ibot_loss_scale = 1.0 / n_global_crops + + # teacher output + @torch.no_grad() + def get_teacher_output(): + x, n_global_crops_teacher = global_crops, n_global_crops + teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) + teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] + teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) + # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss + teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) + ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] + _dim = ibot_teacher_patch_tokens.shape[-1] + n_cls_tokens = teacher_cls_tokens.shape[0] + + if do_ibot and not self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) + buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], + ) + tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) + teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] + masked_teacher_patch_tokens_after_head = tokens_after_head[ + n_cls_tokens : n_cls_tokens + n_masked_patches + ] + elif do_ibot and self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[:n_masked_patches], + ) + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ + :n_masked_patches + ] + else: + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_ibot_softmaxed_centered = None + + if self.cfg.train.centering == "centering": + teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + self.dino_loss.update_center(teacher_cls_tokens_after_head) + if do_ibot: + masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher( + masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp + ) + masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) + self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) + + elif self.cfg.train.centering == "sinkhorn_knopp": + teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + + if do_ibot: + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( + masked_teacher_patch_tokens_after_head, + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + ) + + else: + raise NotImplementedError + + return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered + + teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() + reshard_fsdp_model(self.teacher) + + loss_dict = {} + + loss_accumulator = 0 # for backprop + student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone( + [global_crops, local_crops], masks=[masks, None], is_training=True + ) + + inputs_for_student_head_list = [] + + # 1a: local crops cls tokens + student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) + + # 1b: global crops cls tokens + student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) + + # 1c: global crops patch tokens + if do_ibot: + _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] + ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] + buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) + buffer_tensor_patch_tokens[:n_masked_patches].copy_( + torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) + ) + if not self.ibot_separate_head: + inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) + else: + student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ + :n_masked_patches + ] + + # 2: run + _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) + outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) + + # 3a: local crops cls tokens + student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3b: global crops cls tokens + student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3c: global crops patch tokens + if do_ibot and not self.ibot_separate_head: + student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] + + if n_local_crops > 0: + dino_local_crops_loss = self.dino_loss( + student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), + teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, + ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) + + # store for display + loss_dict["dino_local_crops_loss"] = dino_local_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_local_crops_loss + + # process global crops + loss_scales = 2 # this is here since we process global crops together + + if do_dino: + # compute loss + dino_global_crops_loss = ( + self.dino_loss( + student_output_list=[student_global_cls_tokens_after_head], + teacher_out_softmaxed_centered_list=[ + teacher_dino_softmaxed_centered_list.flatten(0, 1) + ], # these were chunked and stacked in reverse so A is matched to B + ) + * loss_scales + / (n_global_crops_loss_terms + n_local_crops_loss_terms) + ) + + loss_dict["dino_global_crops_loss"] = dino_global_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_global_crops_loss + + student_cls_tokens = student_global_cls_tokens + + if self.do_koleo: + koleo_loss = self.cfg.dino.koleo_loss_weight * sum( + self.koleo_loss(p) for p in student_cls_tokens.chunk(2) + ) # we don't apply koleo loss between cls tokens of a same image + loss_accumulator += koleo_loss + loss_dict["koleo_loss"] = ( + koleo_loss / loss_scales + ) # this is to display the same losses as before but we can remove eventually + + if do_ibot: + # compute loss + ibot_patch_loss = ( + self.ibot_patch_loss.forward_masked( + student_global_masked_patch_tokens_after_head, + masked_teacher_ibot_softmaxed_centered, + student_masks_flat=masks, + n_masked_patches=n_masked_patches, + masks_weight=masks_weight, + ) + * loss_scales + * ibot_loss_scale + ) + + # store for display + loss_dict["ibot_loss"] = ibot_patch_loss / 2 + + # accumulate loss + loss_accumulator += self.ibot_loss_weight * ibot_patch_loss + + self.backprop_loss(loss_accumulator) + + self.fsdp_synchronize_streams() + + return loss_dict + + def fsdp_synchronize_streams(self): + if self.need_to_synchronize_fsdp_streams: + torch.cuda.synchronize() + self.student.dino_head._streams = ( + self.teacher.dino_head._streams + ) = self.student.backbone._streams = self.teacher.backbone._streams + self.need_to_synchronize_fsdp_streams = False + + def update_teacher(self, m): + student_param_list = [] + teacher_param_list = [] + with torch.no_grad(): + for k in self.student.keys(): + for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): + student_param_list += ms.params + teacher_param_list += mt.params + torch._foreach_mul_(teacher_param_list, m) + torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) + + def train(self): + super().train() + self.teacher.eval() + + def get_maybe_fused_params_for_submodel(self, m): + params_groups = get_params_groups_with_decay( + model=m, + lr_decay_rate=self.cfg.optim.layerwise_decay, + patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, + ) + fused_params_groups = fuse_params_groups(params_groups) + logger.info("fusing param groups") + + for g in fused_params_groups: + g["foreach"] = True + return fused_params_groups + + def get_params_groups(self): + all_params_groups = [] + for m in self.student.values(): + all_params_groups += self.get_maybe_fused_params_for_submodel(m) + return all_params_groups + + def prepare_for_distributed_training(self): + logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + if has_batchnorms(self.student): + raise NotImplementedError + # below will synchronize all student subnetworks across gpus: + for k, v in self.student.items(): + self.teacher[k].load_state_dict(self.student[k].state_dict()) + student_model_cfg = self.cfg.compute_precision.student[k] + self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + teacher_model_cfg = self.cfg.compute_precision.teacher[k] + self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) diff --git a/src/dinov2/train/train.py b/src/dinov2/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..473b8d01473654182de9f91c94a2d8720fe096a5 --- /dev/null +++ b/src/dinov2/train/train.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import argparse +import logging +import math +import os +from functools import partial + +from fvcore.common.checkpoint import PeriodicCheckpointer +import torch + +from dinov2.data import SamplerType, make_data_loader, make_dataset +from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator +import dinov2.distributed as distributed +from dinov2.fsdp import FSDPCheckpointer +from dinov2.logging import MetricLogger +from dinov2.utils.config import setup +from dinov2.utils.utils import CosineScheduler + +from dinov2.train.ssl_meta_arch import SSLMetaArch + + +torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default +logger = logging.getLogger("dinov2") + + +def get_args_parser(add_help: bool = True): + parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not attempt to resume from the checkpoint directory. ", + ) + parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") + parser.add_argument("--eval", type=str, default="", help="Eval type to perform") + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--output-dir", + "--output_dir", + default="", + type=str, + help="Output directory to save logs and checkpoints", + ) + + return parser + + +def build_optimizer(cfg, params_groups): + return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) + + +def build_schedulers(cfg): + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + lr = dict( + base_value=cfg.optim["lr"], + final_value=cfg.optim["min_lr"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=0, + ) + wd = dict( + base_value=cfg.optim["weight_decay"], + final_value=cfg.optim["weight_decay_end"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + momentum = dict( + base_value=cfg.teacher["momentum_teacher"], + final_value=cfg.teacher["final_momentum_teacher"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + teacher_temp = dict( + base_value=cfg.teacher["teacher_temp"], + final_value=cfg.teacher["teacher_temp"], + total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=cfg.teacher["warmup_teacher_temp"], + ) + + lr_schedule = CosineScheduler(**lr) + wd_schedule = CosineScheduler(**wd) + momentum_schedule = CosineScheduler(**momentum) + teacher_temp_schedule = CosineScheduler(**teacher_temp) + last_layer_lr_schedule = CosineScheduler(**lr) + + last_layer_lr_schedule.schedule[ + : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH + ] = 0 # mimicking the original schedules + + logger.info("Schedulers ready.") + + return ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) + + +def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): + for param_group in optimizer.param_groups: + is_last_layer = param_group["is_last_layer"] + lr_multiplier = param_group["lr_multiplier"] + wd_multiplier = param_group["wd_multiplier"] + param_group["weight_decay"] = wd * wd_multiplier + param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier + + +def do_test(cfg, model, iteration): + new_state_dict = model.teacher.state_dict() + + if distributed.is_main_process(): + iterstring = str(iteration) + eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) + os.makedirs(eval_dir, exist_ok=True) + # save teacher checkpoint + teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") + torch.save({"teacher": new_state_dict}, teacher_ckp_path) + + +def do_train(cfg, model, resume=False): + model.train() + inputs_dtype = torch.half + fp16_scaler = model.fp16_scaler # for mixed precision training + + # setup optimizer + + optimizer = build_optimizer(cfg, model.get_params_groups()) + ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) = build_schedulers(cfg) + + # checkpointer + checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) + + start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, + period=3 * OFFICIAL_EPOCH_LENGTH, + max_iter=max_iter, + max_to_keep=3, + ) + + # setup data preprocessing + + img_size = cfg.crops.global_crops_size + patch_size = cfg.student.patch_size + n_tokens = (img_size // patch_size) ** 2 + mask_generator = MaskingGenerator( + input_size=(img_size // patch_size, img_size // patch_size), + max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, + ) + + data_transform = DataAugmentationDINO( + cfg.crops.global_crops_scale, + cfg.crops.local_crops_scale, + cfg.crops.local_crops_number, + global_crops_size=cfg.crops.global_crops_size, + local_crops_size=cfg.crops.local_crops_size, + ) + + collate_fn = partial( + collate_data_and_cast, + mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, + mask_probability=cfg.ibot.mask_sample_probability, + n_tokens=n_tokens, + mask_generator=mask_generator, + dtype=inputs_dtype, + ) + + # setup data loader + + dataset = make_dataset( + dataset_str=cfg.train.dataset_path, + transform=data_transform, + target_transform=lambda _: (), + ) + # sampler_type = SamplerType.INFINITE + sampler_type = SamplerType.SHARDED_INFINITE + data_loader = make_data_loader( + dataset=dataset, + batch_size=cfg.train.batch_size_per_gpu, + num_workers=cfg.train.num_workers, + shuffle=True, + seed=start_iter, # TODO: Fix this -- cfg.train.seed + sampler_type=sampler_type, + sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, + drop_last=True, + collate_fn=collate_fn, + ) + + # training loop + + iteration = start_iter + + logger.info("Starting training from iteration {}".format(start_iter)) + metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") + metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) + header = "Training" + + for data in metric_logger.log_every( + data_loader, + 10, + header, + max_iter, + start_iter, + ): + current_batch_size = data["collated_global_crops"].shape[0] / 2 + if iteration > max_iter: + return + + # apply schedules + + lr = lr_schedule[iteration] + wd = wd_schedule[iteration] + mom = momentum_schedule[iteration] + teacher_temp = teacher_temp_schedule[iteration] + last_layer_lr = last_layer_lr_schedule[iteration] + apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) + + # compute losses + + optimizer.zero_grad(set_to_none=True) + loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) + + # clip gradients + + if fp16_scaler is not None: + if cfg.optim.clip_grad: + fp16_scaler.unscale_(optimizer) + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + fp16_scaler.step(optimizer) + fp16_scaler.update() + else: + if cfg.optim.clip_grad: + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + optimizer.step() + + # perform teacher EMA update + + model.update_teacher(mom) + + # logging + + if distributed.get_global_size() > 1: + for v in loss_dict.values(): + torch.distributed.all_reduce(v) + loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} + + if math.isnan(sum(loss_dict_reduced.values())): + logger.info("NaN detected") + raise AssertionError + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + metric_logger.update(lr=lr) + metric_logger.update(wd=wd) + metric_logger.update(mom=mom) + metric_logger.update(last_layer_lr=last_layer_lr) + metric_logger.update(current_batch_size=current_batch_size) + metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) + + # checkpointing and testing + + if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: + do_test(cfg, model, f"training_{iteration}") + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + iteration = iteration + 1 + metric_logger.synchronize_between_processes() + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def main(args): + cfg = setup(args) + + model = SSLMetaArch(cfg).to(torch.device("cuda")) + model.prepare_for_distributed_training() + + logger.info("Model:\n{}".format(model)) + if args.eval_only: + iteration = ( + FSDPCheckpointer(model, save_dir=cfg.train.output_dir) + .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) + .get("iteration", -1) + + 1 + ) + return do_test(cfg, model, f"manual_{iteration}") + + do_train(cfg, model, resume=not args.no_resume) + + +if __name__ == "__main__": + args = get_args_parser(add_help=True).parse_args() + main(args) diff --git a/src/dinov2/utils/__init__.py b/src/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/dinov2/utils/cluster.py b/src/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/src/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/src/dinov2/utils/config.py b/src/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/src/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/src/dinov2/utils/dtype.py b/src/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/src/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/src/dinov2/utils/param_groups.py b/src/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/src/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/src/dinov2/utils/utils.py b/src/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/src/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/src/model_LN_prompt.py b/src/model_LN_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..4e509c1296934bfff35519ce62ba76f71797b496 --- /dev/null +++ b/src/model_LN_prompt.py @@ -0,0 +1,68 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchmetrics.functional import retrieval_average_precision +import pytorch_lightning as pl + +from src.dinov2.models.vision_transformer import vit_base + +from functools import partial + +# from src.clip import clip +from src.options import opts + +def freeze_model(m): + m.requires_grad_(False) + +def freeze_all_but_bn(m): + if not isinstance(m, torch.nn.LayerNorm): + if hasattr(m, 'weight') and m.weight is not None: + m.weight.requires_grad_(False) + if hasattr(m, 'bias') and m.bias is not None: + m.bias.requires_grad_(False) + else: + print("LayerNorm") + +class Model(pl.LightningModule): + def __init__(self): + super().__init__() + + self.opts = opts + + self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0) + print("self.dino", self.dino) + + # Prompt Engineering + self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim)) + self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim)) + + self.distance_fn = lambda x, y: 1.0 - F.cosine_similarity(x, y) + self.loss_fn_triplet = nn.TripletMarginWithDistanceLoss( + distance_function=self.distance_fn, margin=0.2) + + self.emb_cos_loss = nn.CosineEmbeddingLoss(margin=0.2) + + self.loss_kl = nn.KLDivLoss(reduction="batchmean", log_target=True) + + self.best_metric = -1e3 + # normalization layer for the representations z1 and z2 + # self.bn = nn.BatchNorm1d(self.opts.prompt_dim, affine=False) + + def configure_optimizers(self): + if self.opts.model_type == 'one_encoder': + model_params = list(self.dino.parameters()) + else: + model_params = list(self.dino.parameters()) + list(self.clip_sk.parameters()) + + optimizer = torch.optim.Adam([ + {'params': model_params, 'lr': self.opts.clip_LN_lr}, + {'params': [self.sk_prompt] + [self.img_prompt], 'lr': self.opts.prompt_lr}]) + return optimizer + + def forward(self, data, dtype='image'): + if dtype == 'image': + feat = self.dino(data, prompt=self.img_prompt.expand(data.shape[0], -1, -1)) + else: + feat = self.dino(data, prompt=self.sk_prompt.expand(data.shape[0], -1, -1)) + return feat \ No newline at end of file diff --git a/src/options.py b/src/options.py new file mode 100644 index 0000000000000000000000000000000000000000..4912a36ea94dc25438f03b971bcef3967da880db --- /dev/null +++ b/src/options.py @@ -0,0 +1,23 @@ +import argparse + +parser = argparse.ArgumentParser(description='Sketch-based OD') + +parser.add_argument('--exp_name', type=str, default='LN_prompt') + +# ---------------------- +# Training Params +# ---------------------- + +parser.add_argument('--clip_lr', type=float, default=1e-4) +parser.add_argument('--clip_LN_lr', type=float, default=1e-6) +parser.add_argument('--prompt_lr', type=float, default=1e-4) +parser.add_argument('--linear_lr', type=float, default=1e-4) +parser.add_argument('--model_type', type=str, default='one_encoder', choices=['one_encoder', 'two_encoder']) + +# ---------------------- +# ViT Prompt Parameters +# ---------------------- +parser.add_argument('--prompt_dim', type=int, default=768) +parser.add_argument('--n_prompts', type=int, default=3) + +opts = parser.parse_args() \ No newline at end of file