CHSTR commited on
Commit
265ae36
1 Parent(s): 23f154c

Upload src

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. src/__pycache__/model_LN_prompt.cpython-310.pyc +0 -0
  2. src/__pycache__/options.cpython-310.pyc +0 -0
  3. src/dinov2/__init__.py +6 -0
  4. src/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
  5. src/dinov2/configs/__init__.py +22 -0
  6. src/dinov2/configs/eval/vitb14_pretrain.yaml +6 -0
  7. src/dinov2/configs/eval/vitb14_reg4_pretrain.yaml +9 -0
  8. src/dinov2/configs/eval/vitg14_pretrain.yaml +7 -0
  9. src/dinov2/configs/eval/vitg14_reg4_pretrain.yaml +10 -0
  10. src/dinov2/configs/eval/vitl14_pretrain.yaml +6 -0
  11. src/dinov2/configs/eval/vitl14_reg4_pretrain.yaml +9 -0
  12. src/dinov2/configs/eval/vits14_pretrain.yaml +6 -0
  13. src/dinov2/configs/eval/vits14_reg4_pretrain.yaml +9 -0
  14. src/dinov2/configs/ssl_default_config.yaml +118 -0
  15. src/dinov2/configs/train/vitg14.yaml +26 -0
  16. src/dinov2/configs/train/vitl14.yaml +26 -0
  17. src/dinov2/configs/train/vitl16_short.yaml +6 -0
  18. src/dinov2/data/__init__.py +10 -0
  19. src/dinov2/data/adapters.py +28 -0
  20. src/dinov2/data/augmentations.py +118 -0
  21. src/dinov2/data/collate.py +49 -0
  22. src/dinov2/data/datasets/__init__.py +7 -0
  23. src/dinov2/data/datasets/decoders.py +31 -0
  24. src/dinov2/data/datasets/extended.py +38 -0
  25. src/dinov2/data/datasets/image_net.py +290 -0
  26. src/dinov2/data/datasets/image_net_22k.py +302 -0
  27. src/dinov2/data/loaders.py +222 -0
  28. src/dinov2/data/masking.py +86 -0
  29. src/dinov2/data/samplers.py +229 -0
  30. src/dinov2/data/transforms.py +91 -0
  31. src/dinov2/distributed/__init__.py +270 -0
  32. src/dinov2/eval/__init__.py +4 -0
  33. src/dinov2/eval/depth/__init__.py +4 -0
  34. src/dinov2/eval/depth/models/__init__.py +10 -0
  35. src/dinov2/eval/depth/models/backbones/__init__.py +6 -0
  36. src/dinov2/eval/depth/models/backbones/vision_transformer.py +16 -0
  37. src/dinov2/eval/depth/models/builder.py +49 -0
  38. src/dinov2/eval/depth/models/decode_heads/__init__.py +7 -0
  39. src/dinov2/eval/depth/models/decode_heads/decode_head.py +225 -0
  40. src/dinov2/eval/depth/models/decode_heads/dpt_head.py +270 -0
  41. src/dinov2/eval/depth/models/decode_heads/linear_head.py +89 -0
  42. src/dinov2/eval/depth/models/depther/__init__.py +7 -0
  43. src/dinov2/eval/depth/models/depther/base.py +194 -0
  44. src/dinov2/eval/depth/models/depther/encoder_decoder.py +236 -0
  45. src/dinov2/eval/depth/models/losses/__init__.py +7 -0
  46. src/dinov2/eval/depth/models/losses/gradientloss.py +69 -0
  47. src/dinov2/eval/depth/models/losses/sigloss.py +65 -0
  48. src/dinov2/eval/depth/ops/__init__.py +6 -0
  49. src/dinov2/eval/depth/ops/wrappers.py +28 -0
  50. src/dinov2/eval/knn.py +404 -0
src/__pycache__/model_LN_prompt.cpython-310.pyc ADDED
Binary file (2.7 kB). View file
 
src/__pycache__/options.cpython-310.pyc ADDED
Binary file (634 Bytes). View file
 
src/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ __version__ = "0.0.1"
src/dinov2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes). View file
 
src/dinov2/configs/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import pathlib
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ def load_config(config_name: str):
12
+ config_filename = config_name + ".yaml"
13
+ return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
14
+
15
+
16
+ dinov2_default_config = load_config("ssl_default_config")
17
+
18
+
19
+ def load_and_merge_config(config_name: str):
20
+ default_config = OmegaConf.create(dinov2_default_config)
21
+ loaded_config = load_config(config_name)
22
+ return OmegaConf.merge(default_config, loaded_config)
src/dinov2/configs/eval/vitb14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_base
3
+ patch_size: 14
4
+ crops:
5
+ global_crops_size: 518 # this is to set up the position embeddings properly
6
+ local_crops_size: 98
src/dinov2/configs/eval/vitb14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_base
3
+ patch_size: 14
4
+ num_register_tokens: 4
5
+ interpolate_antialias: true
6
+ interpolate_offset: 0.0
7
+ crops:
8
+ global_crops_size: 518 # this is to set up the position embeddings properly
9
+ local_crops_size: 98
src/dinov2/configs/eval/vitg14_pretrain.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_giant2
3
+ patch_size: 14
4
+ ffn_layer: swiglufused
5
+ crops:
6
+ global_crops_size: 518 # this is to set up the position embeddings properly
7
+ local_crops_size: 98
src/dinov2/configs/eval/vitg14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_giant2
3
+ patch_size: 14
4
+ ffn_layer: swiglufused
5
+ num_register_tokens: 4
6
+ interpolate_antialias: true
7
+ interpolate_offset: 0.0
8
+ crops:
9
+ global_crops_size: 518 # this is to set up the position embeddings properly
10
+ local_crops_size: 98
src/dinov2/configs/eval/vitl14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_large
3
+ patch_size: 14
4
+ crops:
5
+ global_crops_size: 518 # this is to set up the position embeddings properly
6
+ local_crops_size: 98
src/dinov2/configs/eval/vitl14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_large
3
+ patch_size: 14
4
+ num_register_tokens: 4
5
+ interpolate_antialias: true
6
+ interpolate_offset: 0.0
7
+ crops:
8
+ global_crops_size: 518 # this is to set up the position embeddings properly
9
+ local_crops_size: 98
src/dinov2/configs/eval/vits14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_small
3
+ patch_size: 14
4
+ crops:
5
+ global_crops_size: 518 # this is to set up the position embeddings properly
6
+ local_crops_size: 98
src/dinov2/configs/eval/vits14_reg4_pretrain.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ student:
2
+ arch: vit_small
3
+ patch_size: 14
4
+ num_register_tokens: 4
5
+ interpolate_antialias: true
6
+ interpolate_offset: 0.0
7
+ crops:
8
+ global_crops_size: 518 # this is to set up the position embeddings properly
9
+ local_crops_size: 98
src/dinov2/configs/ssl_default_config.yaml ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ WEIGHTS: ''
3
+ compute_precision:
4
+ grad_scaler: true
5
+ teacher:
6
+ backbone:
7
+ sharding_strategy: SHARD_GRAD_OP
8
+ mixed_precision:
9
+ param_dtype: fp16
10
+ reduce_dtype: fp16
11
+ buffer_dtype: fp32
12
+ dino_head:
13
+ sharding_strategy: SHARD_GRAD_OP
14
+ mixed_precision:
15
+ param_dtype: fp16
16
+ reduce_dtype: fp16
17
+ buffer_dtype: fp32
18
+ ibot_head:
19
+ sharding_strategy: SHARD_GRAD_OP
20
+ mixed_precision:
21
+ param_dtype: fp16
22
+ reduce_dtype: fp16
23
+ buffer_dtype: fp32
24
+ student:
25
+ backbone:
26
+ sharding_strategy: SHARD_GRAD_OP
27
+ mixed_precision:
28
+ param_dtype: fp16
29
+ reduce_dtype: fp16
30
+ buffer_dtype: fp32
31
+ dino_head:
32
+ sharding_strategy: SHARD_GRAD_OP
33
+ mixed_precision:
34
+ param_dtype: fp16
35
+ reduce_dtype: fp32
36
+ buffer_dtype: fp32
37
+ ibot_head:
38
+ sharding_strategy: SHARD_GRAD_OP
39
+ mixed_precision:
40
+ param_dtype: fp16
41
+ reduce_dtype: fp32
42
+ buffer_dtype: fp32
43
+ dino:
44
+ loss_weight: 1.0
45
+ head_n_prototypes: 65536
46
+ head_bottleneck_dim: 256
47
+ head_nlayers: 3
48
+ head_hidden_dim: 2048
49
+ koleo_loss_weight: 0.1
50
+ ibot:
51
+ loss_weight: 1.0
52
+ mask_sample_probability: 0.5
53
+ mask_ratio_min_max:
54
+ - 0.1
55
+ - 0.5
56
+ separate_head: false
57
+ head_n_prototypes: 65536
58
+ head_bottleneck_dim: 256
59
+ head_nlayers: 3
60
+ head_hidden_dim: 2048
61
+ train:
62
+ batch_size_per_gpu: 64
63
+ dataset_path: ImageNet:split=TRAIN
64
+ output_dir: .
65
+ saveckp_freq: 20
66
+ seed: 0
67
+ num_workers: 10
68
+ OFFICIAL_EPOCH_LENGTH: 1250
69
+ cache_dataset: true
70
+ centering: "centering" # or "sinkhorn_knopp"
71
+ student:
72
+ arch: vit_large
73
+ patch_size: 16
74
+ drop_path_rate: 0.3
75
+ layerscale: 1.0e-05
76
+ drop_path_uniform: true
77
+ pretrained_weights: ''
78
+ ffn_layer: "mlp"
79
+ block_chunks: 0
80
+ qkv_bias: true
81
+ proj_bias: true
82
+ ffn_bias: true
83
+ num_register_tokens: 0
84
+ interpolate_antialias: false
85
+ interpolate_offset: 0.1
86
+ teacher:
87
+ momentum_teacher: 0.992
88
+ final_momentum_teacher: 1
89
+ warmup_teacher_temp: 0.04
90
+ teacher_temp: 0.07
91
+ warmup_teacher_temp_epochs: 30
92
+ optim:
93
+ epochs: 100
94
+ weight_decay: 0.04
95
+ weight_decay_end: 0.4
96
+ base_lr: 0.004 # learning rate for a batch size of 1024
97
+ lr: 0. # will be set after applying scaling rule
98
+ warmup_epochs: 10
99
+ min_lr: 1.0e-06
100
+ clip_grad: 3.0
101
+ freeze_last_layer_epochs: 1
102
+ scaling_rule: sqrt_wrt_1024
103
+ patch_embed_lr_mult: 0.2
104
+ layerwise_decay: 0.9
105
+ adamw_beta1: 0.9
106
+ adamw_beta2: 0.999
107
+ crops:
108
+ global_crops_scale:
109
+ - 0.32
110
+ - 1.0
111
+ local_crops_number: 8
112
+ local_crops_scale:
113
+ - 0.05
114
+ - 0.32
115
+ global_crops_size: 224
116
+ local_crops_size: 96
117
+ evaluation:
118
+ eval_period_iterations: 12500
src/dinov2/configs/train/vitg14.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dino:
2
+ head_n_prototypes: 131072
3
+ head_bottleneck_dim: 384
4
+ ibot:
5
+ separate_head: true
6
+ head_n_prototypes: 131072
7
+ train:
8
+ batch_size_per_gpu: 12
9
+ dataset_path: ImageNet22k
10
+ centering: sinkhorn_knopp
11
+ student:
12
+ arch: vit_giant2
13
+ patch_size: 14
14
+ drop_path_rate: 0.4
15
+ ffn_layer: swiglufused
16
+ block_chunks: 4
17
+ teacher:
18
+ momentum_teacher: 0.994
19
+ optim:
20
+ epochs: 500
21
+ weight_decay_end: 0.2
22
+ base_lr: 2.0e-04 # learning rate for a batch size of 1024
23
+ warmup_epochs: 80
24
+ layerwise_decay: 1.0
25
+ crops:
26
+ local_crops_size: 98
src/dinov2/configs/train/vitl14.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dino:
2
+ head_n_prototypes: 131072
3
+ head_bottleneck_dim: 384
4
+ ibot:
5
+ separate_head: true
6
+ head_n_prototypes: 131072
7
+ train:
8
+ batch_size_per_gpu: 32
9
+ dataset_path: ImageNet22k
10
+ centering: sinkhorn_knopp
11
+ student:
12
+ arch: vit_large
13
+ patch_size: 14
14
+ drop_path_rate: 0.4
15
+ ffn_layer: swiglufused
16
+ block_chunks: 4
17
+ teacher:
18
+ momentum_teacher: 0.994
19
+ optim:
20
+ epochs: 500
21
+ weight_decay_end: 0.2
22
+ base_lr: 2.0e-04 # learning rate for a batch size of 1024
23
+ warmup_epochs: 80
24
+ layerwise_decay: 1.0
25
+ crops:
26
+ local_crops_size: 98
src/dinov2/configs/train/vitl16_short.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # this corresponds to the default config
2
+ train:
3
+ dataset_path: ImageNet:split=TRAIN
4
+ batch_size_per_gpu: 64
5
+ student:
6
+ block_chunks: 4
src/dinov2/data/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .adapters import DatasetWithEnumeratedTargets
7
+ from .loaders import make_data_loader, make_dataset, SamplerType
8
+ from .collate import collate_data_and_cast
9
+ from .masking import MaskingGenerator
10
+ from .augmentations import DataAugmentationDINO
src/dinov2/data/adapters.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, Tuple
7
+
8
+ from torch.utils.data import Dataset
9
+
10
+
11
+ class DatasetWithEnumeratedTargets(Dataset):
12
+ def __init__(self, dataset):
13
+ self._dataset = dataset
14
+
15
+ def get_image_data(self, index: int) -> bytes:
16
+ return self._dataset.get_image_data(index)
17
+
18
+ def get_target(self, index: int) -> Tuple[Any, int]:
19
+ target = self._dataset.get_target(index)
20
+ return (index, target)
21
+
22
+ def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
23
+ image, target = self._dataset[index]
24
+ target = index if target is None else target
25
+ return image, (index, target)
26
+
27
+ def __len__(self) -> int:
28
+ return len(self._dataset)
src/dinov2/data/augmentations.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from torchvision import transforms
9
+
10
+ from .transforms import (
11
+ GaussianBlur,
12
+ make_normalize_transform,
13
+ )
14
+
15
+
16
+ logger = logging.getLogger("dinov2")
17
+
18
+
19
+ class DataAugmentationDINO(object):
20
+ def __init__(
21
+ self,
22
+ global_crops_scale,
23
+ local_crops_scale,
24
+ local_crops_number,
25
+ global_crops_size=224,
26
+ local_crops_size=96,
27
+ ):
28
+ self.global_crops_scale = global_crops_scale
29
+ self.local_crops_scale = local_crops_scale
30
+ self.local_crops_number = local_crops_number
31
+ self.global_crops_size = global_crops_size
32
+ self.local_crops_size = local_crops_size
33
+
34
+ logger.info("###################################")
35
+ logger.info("Using data augmentation parameters:")
36
+ logger.info(f"global_crops_scale: {global_crops_scale}")
37
+ logger.info(f"local_crops_scale: {local_crops_scale}")
38
+ logger.info(f"local_crops_number: {local_crops_number}")
39
+ logger.info(f"global_crops_size: {global_crops_size}")
40
+ logger.info(f"local_crops_size: {local_crops_size}")
41
+ logger.info("###################################")
42
+
43
+ # random resized crop and flip
44
+ self.geometric_augmentation_global = transforms.Compose(
45
+ [
46
+ transforms.RandomResizedCrop(
47
+ global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
48
+ ),
49
+ transforms.RandomHorizontalFlip(p=0.5),
50
+ ]
51
+ )
52
+
53
+ self.geometric_augmentation_local = transforms.Compose(
54
+ [
55
+ transforms.RandomResizedCrop(
56
+ local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
57
+ ),
58
+ transforms.RandomHorizontalFlip(p=0.5),
59
+ ]
60
+ )
61
+
62
+ # color distorsions / blurring
63
+ color_jittering = transforms.Compose(
64
+ [
65
+ transforms.RandomApply(
66
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
67
+ p=0.8,
68
+ ),
69
+ transforms.RandomGrayscale(p=0.2),
70
+ ]
71
+ )
72
+
73
+ global_transfo1_extra = GaussianBlur(p=1.0)
74
+
75
+ global_transfo2_extra = transforms.Compose(
76
+ [
77
+ GaussianBlur(p=0.1),
78
+ transforms.RandomSolarize(threshold=128, p=0.2),
79
+ ]
80
+ )
81
+
82
+ local_transfo_extra = GaussianBlur(p=0.5)
83
+
84
+ # normalization
85
+ self.normalize = transforms.Compose(
86
+ [
87
+ transforms.ToTensor(),
88
+ make_normalize_transform(),
89
+ ]
90
+ )
91
+
92
+ self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
93
+ self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
94
+ self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
95
+
96
+ def __call__(self, image):
97
+ output = {}
98
+
99
+ # global crops:
100
+ im1_base = self.geometric_augmentation_global(image)
101
+ global_crop_1 = self.global_transfo1(im1_base)
102
+
103
+ im2_base = self.geometric_augmentation_global(image)
104
+ global_crop_2 = self.global_transfo2(im2_base)
105
+
106
+ output["global_crops"] = [global_crop_1, global_crop_2]
107
+
108
+ # global crops for teacher:
109
+ output["global_crops_teacher"] = [global_crop_1, global_crop_2]
110
+
111
+ # local crops:
112
+ local_crops = [
113
+ self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
114
+ ]
115
+ output["local_crops"] = local_crops
116
+ output["offsets"] = ()
117
+
118
+ return output
src/dinov2/data/collate.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import random
8
+
9
+
10
+ def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
11
+ # dtype = torch.half # TODO: Remove
12
+
13
+ n_global_crops = len(samples_list[0][0]["global_crops"])
14
+ n_local_crops = len(samples_list[0][0]["local_crops"])
15
+
16
+ collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
17
+
18
+ collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
19
+
20
+ B = len(collated_global_crops)
21
+ N = n_tokens
22
+ n_samples_masked = int(B * mask_probability)
23
+ probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
24
+ upperbound = 0
25
+ masks_list = []
26
+ for i in range(0, n_samples_masked):
27
+ prob_min = probs[i]
28
+ prob_max = probs[i + 1]
29
+ masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
30
+ upperbound += int(N * prob_max)
31
+ for i in range(n_samples_masked, B):
32
+ masks_list.append(torch.BoolTensor(mask_generator(0)))
33
+
34
+ random.shuffle(masks_list)
35
+
36
+ collated_masks = torch.stack(masks_list).flatten(1)
37
+ mask_indices_list = collated_masks.flatten().nonzero().flatten()
38
+
39
+ masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
40
+
41
+ return {
42
+ "collated_global_crops": collated_global_crops.to(dtype),
43
+ "collated_local_crops": collated_local_crops.to(dtype),
44
+ "collated_masks": collated_masks,
45
+ "mask_indices_list": mask_indices_list,
46
+ "masks_weight": masks_weight,
47
+ "upperbound": upperbound,
48
+ "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
49
+ }
src/dinov2/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .image_net import ImageNet
7
+ from .image_net_22k import ImageNet22k
src/dinov2/data/datasets/decoders.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from io import BytesIO
7
+ from typing import Any
8
+
9
+ from PIL import Image
10
+
11
+
12
+ class Decoder:
13
+ def decode(self) -> Any:
14
+ raise NotImplementedError
15
+
16
+
17
+ class ImageDataDecoder(Decoder):
18
+ def __init__(self, image_data: bytes) -> None:
19
+ self._image_data = image_data
20
+
21
+ def decode(self) -> Image:
22
+ f = BytesIO(self._image_data)
23
+ return Image.open(f).convert(mode="RGB")
24
+
25
+
26
+ class TargetDecoder(Decoder):
27
+ def __init__(self, target: Any):
28
+ self._target = target
29
+
30
+ def decode(self) -> Any:
31
+ return self._target
src/dinov2/data/datasets/extended.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, Tuple
7
+
8
+ from torchvision.datasets import VisionDataset
9
+
10
+ from .decoders import TargetDecoder, ImageDataDecoder
11
+
12
+
13
+ class ExtendedVisionDataset(VisionDataset):
14
+ def __init__(self, *args, **kwargs) -> None:
15
+ super().__init__(*args, **kwargs) # type: ignore
16
+
17
+ def get_image_data(self, index: int) -> bytes:
18
+ raise NotImplementedError
19
+
20
+ def get_target(self, index: int) -> Any:
21
+ raise NotImplementedError
22
+
23
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
24
+ try:
25
+ image_data = self.get_image_data(index)
26
+ image = ImageDataDecoder(image_data).decode()
27
+ except Exception as e:
28
+ raise RuntimeError(f"can not read image for sample {index}") from e
29
+ target = self.get_target(index)
30
+ target = TargetDecoder(target).decode()
31
+
32
+ if self.transforms is not None:
33
+ image, target = self.transforms(image, target)
34
+
35
+ return image, target
36
+
37
+ def __len__(self) -> int:
38
+ raise NotImplementedError
src/dinov2/data/datasets/image_net.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ from enum import Enum
8
+ import logging
9
+ import os
10
+ from typing import Callable, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+
14
+ from .extended import ExtendedVisionDataset
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+ _Target = int
19
+
20
+
21
+ class _Split(Enum):
22
+ TRAIN = "train"
23
+ VAL = "val"
24
+ TEST = "test" # NOTE: torchvision does not support the test split
25
+
26
+ @property
27
+ def length(self) -> int:
28
+ split_lengths = {
29
+ _Split.TRAIN: 1_281_167,
30
+ _Split.VAL: 50_000,
31
+ _Split.TEST: 100_000,
32
+ }
33
+ return split_lengths[self]
34
+
35
+ def get_dirname(self, class_id: Optional[str] = None) -> str:
36
+ return self.value if class_id is None else os.path.join(self.value, class_id)
37
+
38
+ def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
39
+ dirname = self.get_dirname(class_id)
40
+ if self == _Split.TRAIN:
41
+ basename = f"{class_id}_{actual_index}"
42
+ else: # self in (_Split.VAL, _Split.TEST):
43
+ basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
44
+ return os.path.join(dirname, basename + ".JPEG")
45
+
46
+ def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
47
+ assert self != _Split.TEST
48
+ dirname, filename = os.path.split(image_relpath)
49
+ class_id = os.path.split(dirname)[-1]
50
+ basename, _ = os.path.splitext(filename)
51
+ actual_index = int(basename.split("_")[-1])
52
+ return class_id, actual_index
53
+
54
+
55
+ class ImageNet(ExtendedVisionDataset):
56
+ Target = Union[_Target]
57
+ Split = Union[_Split]
58
+
59
+ def __init__(
60
+ self,
61
+ *,
62
+ split: "ImageNet.Split",
63
+ root: str,
64
+ extra: str,
65
+ transforms: Optional[Callable] = None,
66
+ transform: Optional[Callable] = None,
67
+ target_transform: Optional[Callable] = None,
68
+ ) -> None:
69
+ super().__init__(root, transforms, transform, target_transform)
70
+ self._extra_root = extra
71
+ self._split = split
72
+
73
+ self._entries = None
74
+ self._class_ids = None
75
+ self._class_names = None
76
+
77
+ @property
78
+ def split(self) -> "ImageNet.Split":
79
+ return self._split
80
+
81
+ def _get_extra_full_path(self, extra_path: str) -> str:
82
+ return os.path.join(self._extra_root, extra_path)
83
+
84
+ def _load_extra(self, extra_path: str) -> np.ndarray:
85
+ extra_full_path = self._get_extra_full_path(extra_path)
86
+ return np.load(extra_full_path, mmap_mode="r")
87
+
88
+ def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
89
+ extra_full_path = self._get_extra_full_path(extra_path)
90
+ os.makedirs(self._extra_root, exist_ok=True)
91
+ np.save(extra_full_path, extra_array)
92
+
93
+ @property
94
+ def _entries_path(self) -> str:
95
+ return f"entries-{self._split.value.upper()}.npy"
96
+
97
+ @property
98
+ def _class_ids_path(self) -> str:
99
+ return f"class-ids-{self._split.value.upper()}.npy"
100
+
101
+ @property
102
+ def _class_names_path(self) -> str:
103
+ return f"class-names-{self._split.value.upper()}.npy"
104
+
105
+ def _get_entries(self) -> np.ndarray:
106
+ if self._entries is None:
107
+ self._entries = self._load_extra(self._entries_path)
108
+ assert self._entries is not None
109
+ return self._entries
110
+
111
+ def _get_class_ids(self) -> np.ndarray:
112
+ if self._split == _Split.TEST:
113
+ assert False, "Class IDs are not available in TEST split"
114
+ if self._class_ids is None:
115
+ self._class_ids = self._load_extra(self._class_ids_path)
116
+ assert self._class_ids is not None
117
+ return self._class_ids
118
+
119
+ def _get_class_names(self) -> np.ndarray:
120
+ if self._split == _Split.TEST:
121
+ assert False, "Class names are not available in TEST split"
122
+ if self._class_names is None:
123
+ self._class_names = self._load_extra(self._class_names_path)
124
+ assert self._class_names is not None
125
+ return self._class_names
126
+
127
+ def find_class_id(self, class_index: int) -> str:
128
+ class_ids = self._get_class_ids()
129
+ return str(class_ids[class_index])
130
+
131
+ def find_class_name(self, class_index: int) -> str:
132
+ class_names = self._get_class_names()
133
+ return str(class_names[class_index])
134
+
135
+ def get_image_data(self, index: int) -> bytes:
136
+ entries = self._get_entries()
137
+ actual_index = entries[index]["actual_index"]
138
+
139
+ class_id = self.get_class_id(index)
140
+
141
+ image_relpath = self.split.get_image_relpath(actual_index, class_id)
142
+ image_full_path = os.path.join(self.root, image_relpath)
143
+ with open(image_full_path, mode="rb") as f:
144
+ image_data = f.read()
145
+ return image_data
146
+
147
+ def get_target(self, index: int) -> Optional[Target]:
148
+ entries = self._get_entries()
149
+ class_index = entries[index]["class_index"]
150
+ return None if self.split == _Split.TEST else int(class_index)
151
+
152
+ def get_targets(self) -> Optional[np.ndarray]:
153
+ entries = self._get_entries()
154
+ return None if self.split == _Split.TEST else entries["class_index"]
155
+
156
+ def get_class_id(self, index: int) -> Optional[str]:
157
+ entries = self._get_entries()
158
+ class_id = entries[index]["class_id"]
159
+ return None if self.split == _Split.TEST else str(class_id)
160
+
161
+ def get_class_name(self, index: int) -> Optional[str]:
162
+ entries = self._get_entries()
163
+ class_name = entries[index]["class_name"]
164
+ return None if self.split == _Split.TEST else str(class_name)
165
+
166
+ def __len__(self) -> int:
167
+ entries = self._get_entries()
168
+ assert len(entries) == self.split.length
169
+ return len(entries)
170
+
171
+ def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
172
+ labels_full_path = os.path.join(self.root, labels_path)
173
+ labels = []
174
+
175
+ try:
176
+ with open(labels_full_path, "r") as f:
177
+ reader = csv.reader(f)
178
+ for row in reader:
179
+ class_id, class_name = row
180
+ labels.append((class_id, class_name))
181
+ except OSError as e:
182
+ raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
183
+
184
+ return labels
185
+
186
+ def _dump_entries(self) -> None:
187
+ split = self.split
188
+ if split == ImageNet.Split.TEST:
189
+ dataset = None
190
+ sample_count = split.length
191
+ max_class_id_length, max_class_name_length = 0, 0
192
+ else:
193
+ labels_path = "labels.txt"
194
+ logger.info(f'loading labels from "{labels_path}"')
195
+ labels = self._load_labels(labels_path)
196
+
197
+ # NOTE: Using torchvision ImageFolder for consistency
198
+ from torchvision.datasets import ImageFolder
199
+
200
+ dataset_root = os.path.join(self.root, split.get_dirname())
201
+ dataset = ImageFolder(dataset_root)
202
+ sample_count = len(dataset)
203
+ max_class_id_length, max_class_name_length = -1, -1
204
+ for sample in dataset.samples:
205
+ _, class_index = sample
206
+ class_id, class_name = labels[class_index]
207
+ max_class_id_length = max(len(class_id), max_class_id_length)
208
+ max_class_name_length = max(len(class_name), max_class_name_length)
209
+
210
+ dtype = np.dtype(
211
+ [
212
+ ("actual_index", "<u4"),
213
+ ("class_index", "<u4"),
214
+ ("class_id", f"U{max_class_id_length}"),
215
+ ("class_name", f"U{max_class_name_length}"),
216
+ ]
217
+ )
218
+ entries_array = np.empty(sample_count, dtype=dtype)
219
+
220
+ if split == ImageNet.Split.TEST:
221
+ old_percent = -1
222
+ for index in range(sample_count):
223
+ percent = 100 * (index + 1) // sample_count
224
+ if percent > old_percent:
225
+ logger.info(f"creating entries: {percent}%")
226
+ old_percent = percent
227
+
228
+ actual_index = index + 1
229
+ class_index = np.uint32(-1)
230
+ class_id, class_name = "", ""
231
+ entries_array[index] = (actual_index, class_index, class_id, class_name)
232
+ else:
233
+ class_names = {class_id: class_name for class_id, class_name in labels}
234
+
235
+ assert dataset
236
+ old_percent = -1
237
+ for index in range(sample_count):
238
+ percent = 100 * (index + 1) // sample_count
239
+ if percent > old_percent:
240
+ logger.info(f"creating entries: {percent}%")
241
+ old_percent = percent
242
+
243
+ image_full_path, class_index = dataset.samples[index]
244
+ image_relpath = os.path.relpath(image_full_path, self.root)
245
+ class_id, actual_index = split.parse_image_relpath(image_relpath)
246
+ class_name = class_names[class_id]
247
+ entries_array[index] = (actual_index, class_index, class_id, class_name)
248
+
249
+ logger.info(f'saving entries to "{self._entries_path}"')
250
+ self._save_extra(entries_array, self._entries_path)
251
+
252
+ def _dump_class_ids_and_names(self) -> None:
253
+ split = self.split
254
+ if split == ImageNet.Split.TEST:
255
+ return
256
+
257
+ entries_array = self._load_extra(self._entries_path)
258
+
259
+ max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
260
+ for entry in entries_array:
261
+ class_index, class_id, class_name = (
262
+ entry["class_index"],
263
+ entry["class_id"],
264
+ entry["class_name"],
265
+ )
266
+ max_class_index = max(int(class_index), max_class_index)
267
+ max_class_id_length = max(len(str(class_id)), max_class_id_length)
268
+ max_class_name_length = max(len(str(class_name)), max_class_name_length)
269
+
270
+ class_count = max_class_index + 1
271
+ class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
272
+ class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
273
+ for entry in entries_array:
274
+ class_index, class_id, class_name = (
275
+ entry["class_index"],
276
+ entry["class_id"],
277
+ entry["class_name"],
278
+ )
279
+ class_ids_array[class_index] = class_id
280
+ class_names_array[class_index] = class_name
281
+
282
+ logger.info(f'saving class IDs to "{self._class_ids_path}"')
283
+ self._save_extra(class_ids_array, self._class_ids_path)
284
+
285
+ logger.info(f'saving class names to "{self._class_names_path}"')
286
+ self._save_extra(class_names_array, self._class_names_path)
287
+
288
+ def dump_extra(self) -> None:
289
+ self._dump_entries()
290
+ self._dump_class_ids_and_names()
src/dinov2/data/datasets/image_net_22k.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from functools import lru_cache
9
+ from gzip import GzipFile
10
+ from io import BytesIO
11
+ from mmap import ACCESS_READ, mmap
12
+ import os
13
+ from typing import Any, Callable, List, Optional, Set, Tuple
14
+ import warnings
15
+
16
+ import numpy as np
17
+
18
+ from .extended import ExtendedVisionDataset
19
+
20
+
21
+ _Labels = int
22
+
23
+ _DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
24
+
25
+
26
+ @dataclass
27
+ class _ClassEntry:
28
+ block_offset: int
29
+ maybe_filename: Optional[str] = None
30
+
31
+
32
+ @dataclass
33
+ class _Entry:
34
+ class_index: int # noqa: E701
35
+ start_offset: int
36
+ end_offset: int
37
+ filename: str
38
+
39
+
40
+ class _Split(Enum):
41
+ TRAIN = "train"
42
+ VAL = "val"
43
+
44
+ @property
45
+ def length(self) -> int:
46
+ return {
47
+ _Split.TRAIN: 11_797_647,
48
+ _Split.VAL: 561_050,
49
+ }[self]
50
+
51
+ def entries_path(self):
52
+ return f"imagenet21kp_{self.value}.txt"
53
+
54
+
55
+ def _get_tarball_path(class_id: str) -> str:
56
+ return f"{class_id}.tar"
57
+
58
+
59
+ def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
60
+ @lru_cache(maxsize=mmap_cache_size)
61
+ def _mmap_tarball(class_id: str) -> mmap:
62
+ tarball_path = _get_tarball_path(class_id)
63
+ tarball_full_path = os.path.join(tarballs_root, tarball_path)
64
+ with open(tarball_full_path) as f:
65
+ return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
66
+
67
+ return _mmap_tarball
68
+
69
+
70
+ class ImageNet22k(ExtendedVisionDataset):
71
+ _GZIPPED_INDICES: Set[int] = {
72
+ 841_545,
73
+ 1_304_131,
74
+ 2_437_921,
75
+ 2_672_079,
76
+ 2_795_676,
77
+ 2_969_786,
78
+ 6_902_965,
79
+ 6_903_550,
80
+ 6_903_628,
81
+ 7_432_557,
82
+ 7_432_589,
83
+ 7_813_809,
84
+ 8_329_633,
85
+ 10_296_990,
86
+ 10_417_652,
87
+ 10_492_265,
88
+ 10_598_078,
89
+ 10_782_398,
90
+ 10_902_612,
91
+ 11_203_736,
92
+ 11_342_890,
93
+ 11_397_596,
94
+ 11_589_762,
95
+ 11_705_103,
96
+ 12_936_875,
97
+ 13_289_782,
98
+ }
99
+ Labels = _Labels
100
+
101
+ def __init__(
102
+ self,
103
+ *,
104
+ root: str,
105
+ extra: str,
106
+ transforms: Optional[Callable] = None,
107
+ transform: Optional[Callable] = None,
108
+ target_transform: Optional[Callable] = None,
109
+ mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
110
+ ) -> None:
111
+ super().__init__(root, transforms, transform, target_transform)
112
+ self._extra_root = extra
113
+
114
+ entries_path = self._get_entries_path(root)
115
+ self._entries = self._load_extra(entries_path)
116
+
117
+ class_ids_path = self._get_class_ids_path(root)
118
+ self._class_ids = self._load_extra(class_ids_path)
119
+
120
+ self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
121
+ self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
122
+
123
+ def _get_entries_path(self, root: Optional[str] = None) -> str:
124
+ return "entries.npy"
125
+
126
+ def _get_class_ids_path(self, root: Optional[str] = None) -> str:
127
+ return "class-ids.npy"
128
+
129
+ def _find_class_ids(self, path: str) -> List[str]:
130
+ class_ids = []
131
+
132
+ with os.scandir(path) as entries:
133
+ for entry in entries:
134
+ root, ext = os.path.splitext(entry.name)
135
+ if ext != ".tar":
136
+ continue
137
+ class_ids.append(root)
138
+
139
+ return sorted(class_ids)
140
+
141
+ def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
142
+ root = self.get_root(root)
143
+ entries: List[_Entry] = []
144
+ class_ids = self._find_class_ids(root)
145
+
146
+ for class_index, class_id in enumerate(class_ids):
147
+ path = os.path.join(root, "blocks", f"{class_id}.log")
148
+ class_entries = []
149
+
150
+ try:
151
+ with open(path) as f:
152
+ for line in f:
153
+ line = line.rstrip()
154
+ block, filename = line.split(":")
155
+ block_offset = int(block[6:])
156
+ filename = filename[1:]
157
+
158
+ maybe_filename = None
159
+ if filename != "** Block of NULs **":
160
+ maybe_filename = filename
161
+ _, ext = os.path.splitext(filename)
162
+ # assert ext == ".JPEG"
163
+
164
+ class_entry = _ClassEntry(block_offset, maybe_filename)
165
+ class_entries.append(class_entry)
166
+ except OSError as e:
167
+ raise RuntimeError(f'can not read blocks file "{path}"') from e
168
+
169
+ assert class_entries[-1].maybe_filename is None
170
+
171
+ for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
172
+ assert class_entry1.block_offset <= class_entry2.block_offset
173
+ start_offset = 512 * class_entry1.block_offset
174
+ end_offset = 512 * class_entry2.block_offset
175
+ assert class_entry1.maybe_filename is not None
176
+ filename = class_entry1.maybe_filename
177
+ entry = _Entry(class_index, start_offset, end_offset, filename)
178
+ # Skip invalid image files (PIL throws UnidentifiedImageError)
179
+ if filename == "n06470073_47249.JPEG":
180
+ continue
181
+ entries.append(entry)
182
+
183
+ return entries, class_ids
184
+
185
+ def _load_extra(self, extra_path: str) -> np.ndarray:
186
+ extra_root = self._extra_root
187
+ extra_full_path = os.path.join(extra_root, extra_path)
188
+ return np.load(extra_full_path, mmap_mode="r")
189
+
190
+ def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
191
+ extra_root = self._extra_root
192
+ extra_full_path = os.path.join(extra_root, extra_path)
193
+ os.makedirs(extra_root, exist_ok=True)
194
+ np.save(extra_full_path, extra_array)
195
+
196
+ @property
197
+ def _tarballs_root(self) -> str:
198
+ return self.root
199
+
200
+ def find_class_id(self, class_index: int) -> str:
201
+ return str(self._class_ids[class_index])
202
+
203
+ def get_image_data(self, index: int) -> bytes:
204
+ entry = self._entries[index]
205
+ class_id = entry["class_id"]
206
+ class_mmap = self._mmap_tarball(class_id)
207
+
208
+ start_offset, end_offset = entry["start_offset"], entry["end_offset"]
209
+ try:
210
+ mapped_data = class_mmap[start_offset:end_offset]
211
+ data = mapped_data[512:] # Skip entry header block
212
+
213
+ if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
214
+ assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
215
+ with GzipFile(fileobj=BytesIO(data)) as g:
216
+ data = g.read()
217
+ except Exception as e:
218
+ raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
219
+
220
+ return data
221
+
222
+ def get_target(self, index: int) -> Any:
223
+ return int(self._entries[index]["class_index"])
224
+
225
+ def get_targets(self) -> np.ndarray:
226
+ return self._entries["class_index"]
227
+
228
+ def get_class_id(self, index: int) -> str:
229
+ return str(self._entries[index]["class_id"])
230
+
231
+ def get_class_ids(self) -> np.ndarray:
232
+ return self._entries["class_id"]
233
+
234
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
235
+ with warnings.catch_warnings():
236
+ warnings.simplefilter("ignore")
237
+ return super().__getitem__(index)
238
+
239
+ def __len__(self) -> int:
240
+ return len(self._entries)
241
+
242
+ def _dump_entries(self, *args, **kwargs) -> None:
243
+ entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
244
+
245
+ max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
246
+ for entry in entries:
247
+ class_id = class_ids[entry.class_index]
248
+ max_class_index = max(entry.class_index, max_class_index)
249
+ max_class_id_length = max(len(class_id), max_class_id_length)
250
+ max_filename_length = max(len(entry.filename), max_filename_length)
251
+
252
+ dtype = np.dtype(
253
+ [
254
+ ("class_index", "<u4"),
255
+ ("class_id", f"U{max_class_id_length}"),
256
+ ("start_offset", "<u4"),
257
+ ("end_offset", "<u4"),
258
+ ("filename", f"U{max_filename_length}"),
259
+ ]
260
+ )
261
+ sample_count = len(entries)
262
+ entries_array = np.empty(sample_count, dtype=dtype)
263
+ for i, entry in enumerate(entries):
264
+ class_index = entry.class_index
265
+ class_id = class_ids[class_index]
266
+ start_offset = entry.start_offset
267
+ end_offset = entry.end_offset
268
+ filename = entry.filename
269
+ entries_array[i] = (
270
+ class_index,
271
+ class_id,
272
+ start_offset,
273
+ end_offset,
274
+ filename,
275
+ )
276
+
277
+ entries_path = self._get_entries_path(*args, **kwargs)
278
+ self._save_extra(entries_array, entries_path)
279
+
280
+ def _dump_class_ids(self, *args, **kwargs) -> None:
281
+ entries_path = self._get_entries_path(*args, **kwargs)
282
+ entries_array = self._load_extra(entries_path)
283
+
284
+ max_class_id_length, max_class_index = -1, -1
285
+ for entry in entries_array:
286
+ class_index, class_id = entry["class_index"], entry["class_id"]
287
+ max_class_index = max(int(class_index), max_class_index)
288
+ max_class_id_length = max(len(str(class_id)), max_class_id_length)
289
+
290
+ class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
291
+ for entry in entries_array:
292
+ class_index, class_id = entry["class_index"], entry["class_id"]
293
+ class_ids_array[class_index] = class_id
294
+ class_ids_path = self._get_class_ids_path(*args, **kwargs)
295
+ self._save_extra(class_ids_array, class_ids_path)
296
+
297
+ def _dump_extra(self, *args, **kwargs) -> None:
298
+ self._dump_entries(*args, *kwargs)
299
+ self._dump_class_ids(*args, *kwargs)
300
+
301
+ def dump_extra(self, root: Optional[str] = None) -> None:
302
+ return self._dump_extra(root)
src/dinov2/data/loaders.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from enum import Enum
8
+ from typing import Any, Callable, List, Optional, TypeVar
9
+
10
+ import torch
11
+ from torch.utils.data import Sampler
12
+
13
+ from .datasets import ImageNet, ImageNet22k
14
+ from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ class SamplerType(Enum):
21
+ DISTRIBUTED = 0
22
+ EPOCH = 1
23
+ INFINITE = 2
24
+ SHARDED_INFINITE = 3
25
+ SHARDED_INFINITE_NEW = 4
26
+
27
+
28
+ def _make_bool_str(b: bool) -> str:
29
+ return "yes" if b else "no"
30
+
31
+
32
+ def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
33
+ def transform(sample):
34
+ image, target = sample
35
+ if image_transform is not None:
36
+ image = image_transform(image)
37
+ if target_transform is not None:
38
+ target = target_transform(target)
39
+ return image, target
40
+
41
+ return transform
42
+
43
+
44
+ def _parse_dataset_str(dataset_str: str):
45
+ tokens = dataset_str.split(":")
46
+
47
+ name = tokens[0]
48
+ kwargs = {}
49
+
50
+ for token in tokens[1:]:
51
+ key, value = token.split("=")
52
+ assert key in ("root", "extra", "split")
53
+ kwargs[key] = value
54
+
55
+ if name == "ImageNet":
56
+ class_ = ImageNet
57
+ if "split" in kwargs:
58
+ kwargs["split"] = ImageNet.Split[kwargs["split"]]
59
+ elif name == "ImageNet22k":
60
+ class_ = ImageNet22k
61
+ else:
62
+ raise ValueError(f'Unsupported dataset "{name}"')
63
+
64
+ return class_, kwargs
65
+
66
+
67
+ def make_dataset(
68
+ *,
69
+ dataset_str: str,
70
+ transform: Optional[Callable] = None,
71
+ target_transform: Optional[Callable] = None,
72
+ ):
73
+ """
74
+ Creates a dataset with the specified parameters.
75
+
76
+ Args:
77
+ dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
78
+ transform: A transform to apply to images.
79
+ target_transform: A transform to apply to targets.
80
+
81
+ Returns:
82
+ The created dataset.
83
+ """
84
+ logger.info(f'using dataset: "{dataset_str}"')
85
+
86
+ class_, kwargs = _parse_dataset_str(dataset_str)
87
+ dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
88
+
89
+ logger.info(f"# of dataset samples: {len(dataset):,d}")
90
+
91
+ # Aggregated datasets do not expose (yet) these attributes, so add them.
92
+ if not hasattr(dataset, "transform"):
93
+ setattr(dataset, "transform", transform)
94
+ if not hasattr(dataset, "target_transform"):
95
+ setattr(dataset, "target_transform", target_transform)
96
+
97
+ return dataset
98
+
99
+
100
+ def _make_sampler(
101
+ *,
102
+ dataset,
103
+ type: Optional[SamplerType] = None,
104
+ shuffle: bool = False,
105
+ seed: int = 0,
106
+ size: int = -1,
107
+ advance: int = 0,
108
+ ) -> Optional[Sampler]:
109
+ sample_count = len(dataset)
110
+
111
+ if type == SamplerType.INFINITE:
112
+ logger.info("sampler: infinite")
113
+ if size > 0:
114
+ raise ValueError("sampler size > 0 is invalid")
115
+ return InfiniteSampler(
116
+ sample_count=sample_count,
117
+ shuffle=shuffle,
118
+ seed=seed,
119
+ advance=advance,
120
+ )
121
+ elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
122
+ logger.info("sampler: sharded infinite")
123
+ if size > 0:
124
+ raise ValueError("sampler size > 0 is invalid")
125
+ # TODO: Remove support for old shuffling
126
+ use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
127
+ return ShardedInfiniteSampler(
128
+ sample_count=sample_count,
129
+ shuffle=shuffle,
130
+ seed=seed,
131
+ advance=advance,
132
+ use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
133
+ )
134
+ elif type == SamplerType.EPOCH:
135
+ logger.info("sampler: epoch")
136
+ if advance > 0:
137
+ raise NotImplementedError("sampler advance > 0 is not supported")
138
+ size = size if size > 0 else sample_count
139
+ logger.info(f"# of samples / epoch: {size:,d}")
140
+ return EpochSampler(
141
+ size=size,
142
+ sample_count=sample_count,
143
+ shuffle=shuffle,
144
+ seed=seed,
145
+ )
146
+ elif type == SamplerType.DISTRIBUTED:
147
+ logger.info("sampler: distributed")
148
+ if size > 0:
149
+ raise ValueError("sampler size > 0 is invalid")
150
+ if advance > 0:
151
+ raise ValueError("sampler advance > 0 is invalid")
152
+ return torch.utils.data.DistributedSampler(
153
+ dataset=dataset,
154
+ shuffle=shuffle,
155
+ seed=seed,
156
+ drop_last=False,
157
+ )
158
+
159
+ logger.info("sampler: none")
160
+ return None
161
+
162
+
163
+ T = TypeVar("T")
164
+
165
+
166
+ def make_data_loader(
167
+ *,
168
+ dataset,
169
+ batch_size: int,
170
+ num_workers: int,
171
+ shuffle: bool = True,
172
+ seed: int = 0,
173
+ sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
174
+ sampler_size: int = -1,
175
+ sampler_advance: int = 0,
176
+ drop_last: bool = True,
177
+ persistent_workers: bool = False,
178
+ collate_fn: Optional[Callable[[List[T]], Any]] = None,
179
+ ):
180
+ """
181
+ Creates a data loader with the specified parameters.
182
+
183
+ Args:
184
+ dataset: A dataset (third party, LaViDa or WebDataset).
185
+ batch_size: The size of batches to generate.
186
+ num_workers: The number of workers to use.
187
+ shuffle: Whether to shuffle samples.
188
+ seed: The random seed to use.
189
+ sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
190
+ sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
191
+ sampler_advance: How many samples to skip (when applicable).
192
+ drop_last: Whether the last non-full batch of data should be dropped.
193
+ persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
194
+ collate_fn: Function that performs batch collation
195
+ """
196
+
197
+ sampler = _make_sampler(
198
+ dataset=dataset,
199
+ type=sampler_type,
200
+ shuffle=shuffle,
201
+ seed=seed,
202
+ size=sampler_size,
203
+ advance=sampler_advance,
204
+ )
205
+
206
+ logger.info("using PyTorch data loader")
207
+ data_loader = torch.utils.data.DataLoader(
208
+ dataset,
209
+ sampler=sampler,
210
+ batch_size=batch_size,
211
+ num_workers=num_workers,
212
+ pin_memory=True,
213
+ drop_last=drop_last,
214
+ persistent_workers=persistent_workers,
215
+ collate_fn=collate_fn,
216
+ )
217
+
218
+ try:
219
+ logger.info(f"# of batches: {len(data_loader):,d}")
220
+ except TypeError: # data loader has no length
221
+ logger.info("infinite data loader")
222
+ return data_loader
src/dinov2/data/masking.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import math
8
+ import numpy as np
9
+
10
+
11
+ class MaskingGenerator:
12
+ def __init__(
13
+ self,
14
+ input_size,
15
+ num_masking_patches=None,
16
+ min_num_patches=4,
17
+ max_num_patches=None,
18
+ min_aspect=0.3,
19
+ max_aspect=None,
20
+ ):
21
+ if not isinstance(input_size, tuple):
22
+ input_size = (input_size,) * 2
23
+ self.height, self.width = input_size
24
+
25
+ self.num_patches = self.height * self.width
26
+ self.num_masking_patches = num_masking_patches
27
+
28
+ self.min_num_patches = min_num_patches
29
+ self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
30
+
31
+ max_aspect = max_aspect or 1 / min_aspect
32
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
33
+
34
+ def __repr__(self):
35
+ repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
36
+ self.height,
37
+ self.width,
38
+ self.min_num_patches,
39
+ self.max_num_patches,
40
+ self.num_masking_patches,
41
+ self.log_aspect_ratio[0],
42
+ self.log_aspect_ratio[1],
43
+ )
44
+ return repr_str
45
+
46
+ def get_shape(self):
47
+ return self.height, self.width
48
+
49
+ def _mask(self, mask, max_mask_patches):
50
+ delta = 0
51
+ for _ in range(10):
52
+ target_area = random.uniform(self.min_num_patches, max_mask_patches)
53
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
54
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
55
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
56
+ if w < self.width and h < self.height:
57
+ top = random.randint(0, self.height - h)
58
+ left = random.randint(0, self.width - w)
59
+
60
+ num_masked = mask[top : top + h, left : left + w].sum()
61
+ # Overlap
62
+ if 0 < h * w - num_masked <= max_mask_patches:
63
+ for i in range(top, top + h):
64
+ for j in range(left, left + w):
65
+ if mask[i, j] == 0:
66
+ mask[i, j] = 1
67
+ delta += 1
68
+
69
+ if delta > 0:
70
+ break
71
+ return delta
72
+
73
+ def __call__(self, num_masking_patches=0):
74
+ mask = np.zeros(shape=self.get_shape(), dtype=bool)
75
+ mask_count = 0
76
+ while mask_count < num_masking_patches:
77
+ max_mask_patches = num_masking_patches - mask_count
78
+ max_mask_patches = min(max_mask_patches, self.max_num_patches)
79
+
80
+ delta = self._mask(mask, max_mask_patches)
81
+ if delta == 0:
82
+ break
83
+ else:
84
+ mask_count += delta
85
+
86
+ return mask
src/dinov2/data/samplers.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ from typing import Any, Optional
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.utils.data.sampler import Sampler
13
+
14
+ import dinov2.distributed as distributed
15
+
16
+
17
+ class EpochSampler(Sampler):
18
+ def __init__(
19
+ self,
20
+ *,
21
+ size: int,
22
+ sample_count: int,
23
+ shuffle: bool = False,
24
+ seed: int = 0,
25
+ start: Optional[int] = None,
26
+ step: Optional[int] = None,
27
+ ):
28
+ self._size = size
29
+ self._sample_count = sample_count
30
+ self._shuffle = shuffle
31
+ self._seed = seed
32
+ self._start = distributed.get_global_rank() if start is None else start
33
+ self._step = distributed.get_global_size() if step is None else step
34
+ self._epoch = 0
35
+
36
+ def __iter__(self):
37
+ count = (self._size + self._sample_count - 1) // self._sample_count
38
+ tiled_indices = np.tile(np.arange(self._sample_count), count)
39
+ if self._shuffle:
40
+ seed = self._seed * self._epoch if self._seed != 0 else self._epoch
41
+ rng = np.random.default_rng(seed)
42
+ iterable = rng.choice(tiled_indices, self._size, replace=False)
43
+ else:
44
+ iterable = tiled_indices[: self._size]
45
+
46
+ yield from itertools.islice(iterable, self._start, None, self._step)
47
+
48
+ def __len__(self):
49
+ return (self._size - self._start + self._step - 1) // self._step
50
+
51
+ def set_epoch(self, epoch):
52
+ self._epoch = epoch
53
+
54
+
55
+ def _get_numpy_dtype(size: int) -> Any:
56
+ return np.int32 if size <= 2**31 else np.int64
57
+
58
+
59
+ def _get_torch_dtype(size: int) -> Any:
60
+ return torch.int32 if size <= 2**31 else torch.int64
61
+
62
+
63
+ def _generate_randperm_indices(*, size: int, generator: torch.Generator):
64
+ """Generate the indices of a random permutation."""
65
+ dtype = _get_torch_dtype(size)
66
+ # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
67
+ perm = torch.arange(size, dtype=dtype)
68
+ for i in range(size):
69
+ j = torch.randint(i, size, size=(1,), generator=generator).item()
70
+
71
+ # Always swap even if no-op
72
+ value = perm[j].item()
73
+ perm[j] = perm[i].item()
74
+ perm[i] = value
75
+ yield value
76
+
77
+
78
+ class InfiniteSampler(Sampler):
79
+ def __init__(
80
+ self,
81
+ *,
82
+ sample_count: int,
83
+ shuffle: bool = False,
84
+ seed: int = 0,
85
+ start: Optional[int] = None,
86
+ step: Optional[int] = None,
87
+ advance: int = 0,
88
+ ):
89
+ self._sample_count = sample_count
90
+ self._seed = seed
91
+ self._shuffle = shuffle
92
+ self._start = distributed.get_global_rank() if start is None else start
93
+ self._step = distributed.get_global_size() if step is None else step
94
+ self._advance = advance
95
+
96
+ def __iter__(self):
97
+ if self._shuffle:
98
+ iterator = self._shuffled_iterator()
99
+ else:
100
+ iterator = self._iterator()
101
+
102
+ yield from itertools.islice(iterator, self._advance, None)
103
+
104
+ def _iterator(self):
105
+ assert not self._shuffle
106
+
107
+ while True:
108
+ iterable = range(self._sample_count)
109
+ yield from itertools.islice(iterable, self._start, None, self._step)
110
+
111
+ def _shuffled_iterator(self):
112
+ assert self._shuffle
113
+
114
+ # Instantiate a generator here (rather than in the ctor) to keep the class
115
+ # picklable (requirement of mp.spawn)
116
+ generator = torch.Generator().manual_seed(self._seed)
117
+
118
+ while True:
119
+ iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
120
+ yield from itertools.islice(iterable, self._start, None, self._step)
121
+
122
+
123
+ # The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
124
+ # but avoids a full in-place random permutation generation.
125
+ def _shuffle_tensor_slice(
126
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
127
+ ) -> np.ndarray:
128
+ stop = len(tensor)
129
+ count = stop // step
130
+ drop_count = stop - step * count
131
+ if drop_count:
132
+ warnings.warn(f"# of dropped samples: {drop_count}")
133
+
134
+ dtype = _get_numpy_dtype(stop)
135
+ result = np.empty(count, dtype=dtype)
136
+
137
+ for i in range(count):
138
+ j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
139
+
140
+ result[i] = result[j]
141
+ result[j] = tensor[start + i * step].item()
142
+
143
+ return result
144
+
145
+
146
+ def _new_shuffle_tensor_slice(
147
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
148
+ ) -> np.ndarray:
149
+ stop = len(tensor)
150
+ count = stop // step
151
+ dtype = torch.int64 # Needed for using randperm result as indices
152
+ count = stop // step
153
+ drop_count = stop - step * count
154
+ if drop_count:
155
+ warnings.warn(f"# of dropped samples: {drop_count}")
156
+ indices = torch.randperm(count, dtype=dtype, generator=generator)
157
+ return tensor[start::step][indices].numpy()
158
+
159
+
160
+ def _make_seed(seed: int, start: int, iter_count: int) -> int:
161
+ # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
162
+ return seed + start + (iter_count << 24)
163
+
164
+
165
+ class ShardedInfiniteSampler(Sampler):
166
+ def __init__(
167
+ self,
168
+ *,
169
+ sample_count: int,
170
+ shuffle: bool = False,
171
+ seed: int = 0,
172
+ start: Optional[int] = None,
173
+ step: Optional[int] = None,
174
+ advance: int = 0,
175
+ use_new_shuffle_tensor_slice: bool = False,
176
+ ):
177
+ self._sample_count = sample_count
178
+ self._seed = seed
179
+ self._shuffle = shuffle
180
+ self._start = distributed.get_global_rank() if start is None else start
181
+ self._step = distributed.get_global_size() if step is None else step
182
+ self._advance = advance
183
+ self._iter_count = 0
184
+ self._shuffle_tensor_slice_fn = (
185
+ _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
186
+ )
187
+
188
+ def __iter__(self):
189
+ iter_count = self._advance // self._sample_count
190
+ if iter_count > 0:
191
+ self._advance -= iter_count * self._sample_count
192
+ self._iter_count += iter_count
193
+
194
+ if self._shuffle:
195
+ iterator = self._shuffled_iterator()
196
+ else:
197
+ iterator = self._iterator()
198
+
199
+ yield from itertools.islice(iterator, self._advance, None)
200
+
201
+ def _iterator(self):
202
+ assert not self._shuffle
203
+
204
+ while True:
205
+ iterable = range(self._sample_count)
206
+ yield from itertools.islice(iterable, self._start, None, self._step)
207
+
208
+ def _shuffled_iterator(self):
209
+ assert self._shuffle
210
+
211
+ # Instantiate a generator here (rather than in the ctor) to be keep the class
212
+ # picklable (requirement of mp.spawn)
213
+ generator = torch.Generator()
214
+
215
+ # Always shuffle everything first
216
+ generator.manual_seed(self._seed)
217
+ dtype = _get_torch_dtype(self._sample_count)
218
+ perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
219
+
220
+ while True:
221
+ # Re-seed on each iteration to allow skipping whole permutations
222
+ seed = _make_seed(self._seed, self._start, self._iter_count)
223
+ generator.manual_seed(seed)
224
+
225
+ iterable = self._shuffle_tensor_slice_fn(
226
+ tensor=perm, start=self._start, step=self._step, generator=generator
227
+ )
228
+ yield from iterable
229
+ self._iter_count += 1
src/dinov2/data/transforms.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Sequence
7
+
8
+ import torch
9
+ from torchvision import transforms
10
+
11
+
12
+ class GaussianBlur(transforms.RandomApply):
13
+ """
14
+ Apply Gaussian Blur to the PIL image.
15
+ """
16
+
17
+ def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
18
+ # NOTE: torchvision is applying 1 - probability to return the original image
19
+ keep_p = 1 - p
20
+ transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
21
+ super().__init__(transforms=[transform], p=keep_p)
22
+
23
+
24
+ class MaybeToTensor(transforms.ToTensor):
25
+ """
26
+ Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
27
+ """
28
+
29
+ def __call__(self, pic):
30
+ """
31
+ Args:
32
+ pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
33
+ Returns:
34
+ Tensor: Converted image.
35
+ """
36
+ if isinstance(pic, torch.Tensor):
37
+ return pic
38
+ return super().__call__(pic)
39
+
40
+
41
+ # Use timm's names
42
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
43
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
44
+
45
+
46
+ def make_normalize_transform(
47
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
48
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
49
+ ) -> transforms.Normalize:
50
+ return transforms.Normalize(mean=mean, std=std)
51
+
52
+
53
+ # This roughly matches torchvision's preset for classification training:
54
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
55
+ def make_classification_train_transform(
56
+ *,
57
+ crop_size: int = 224,
58
+ interpolation=transforms.InterpolationMode.BICUBIC,
59
+ hflip_prob: float = 0.5,
60
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
61
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
62
+ ):
63
+ transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
64
+ if hflip_prob > 0.0:
65
+ transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
66
+ transforms_list.extend(
67
+ [
68
+ MaybeToTensor(),
69
+ make_normalize_transform(mean=mean, std=std),
70
+ ]
71
+ )
72
+ return transforms.Compose(transforms_list)
73
+
74
+
75
+ # This matches (roughly) torchvision's preset for classification evaluation:
76
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
77
+ def make_classification_eval_transform(
78
+ *,
79
+ resize_size: int = 256,
80
+ interpolation=transforms.InterpolationMode.BICUBIC,
81
+ crop_size: int = 224,
82
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
83
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
84
+ ) -> transforms.Compose:
85
+ transforms_list = [
86
+ transforms.Resize(resize_size, interpolation=interpolation),
87
+ transforms.CenterCrop(crop_size),
88
+ MaybeToTensor(),
89
+ make_normalize_transform(mean=mean, std=std),
90
+ ]
91
+ return transforms.Compose(transforms_list)
src/dinov2/distributed/__init__.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import re
9
+ import socket
10
+ from typing import Dict, List
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+
15
+ _LOCAL_RANK = -1
16
+ _LOCAL_WORLD_SIZE = -1
17
+
18
+
19
+ def is_enabled() -> bool:
20
+ """
21
+ Returns:
22
+ True if distributed training is enabled
23
+ """
24
+ return dist.is_available() and dist.is_initialized()
25
+
26
+
27
+ def get_global_size() -> int:
28
+ """
29
+ Returns:
30
+ The number of processes in the process group
31
+ """
32
+ return dist.get_world_size() if is_enabled() else 1
33
+
34
+
35
+ def get_global_rank() -> int:
36
+ """
37
+ Returns:
38
+ The rank of the current process within the global process group.
39
+ """
40
+ return dist.get_rank() if is_enabled() else 0
41
+
42
+
43
+ def get_local_rank() -> int:
44
+ """
45
+ Returns:
46
+ The rank of the current process within the local (per-machine) process group.
47
+ """
48
+ if not is_enabled():
49
+ return 0
50
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
51
+ return _LOCAL_RANK
52
+
53
+
54
+ def get_local_size() -> int:
55
+ """
56
+ Returns:
57
+ The size of the per-machine process group,
58
+ i.e. the number of processes per machine.
59
+ """
60
+ if not is_enabled():
61
+ return 1
62
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
63
+ return _LOCAL_WORLD_SIZE
64
+
65
+
66
+ def is_main_process() -> bool:
67
+ """
68
+ Returns:
69
+ True if the current process is the main one.
70
+ """
71
+ return get_global_rank() == 0
72
+
73
+
74
+ def _restrict_print_to_main_process() -> None:
75
+ """
76
+ This function disables printing when not in the main process
77
+ """
78
+ import builtins as __builtin__
79
+
80
+ builtin_print = __builtin__.print
81
+
82
+ def print(*args, **kwargs):
83
+ force = kwargs.pop("force", False)
84
+ if is_main_process() or force:
85
+ builtin_print(*args, **kwargs)
86
+
87
+ __builtin__.print = print
88
+
89
+
90
+ def _get_master_port(seed: int = 0) -> int:
91
+ MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
92
+
93
+ master_port_str = os.environ.get("MASTER_PORT")
94
+ if master_port_str is None:
95
+ rng = random.Random(seed)
96
+ return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
97
+
98
+ return int(master_port_str)
99
+
100
+
101
+ def _get_available_port() -> int:
102
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
103
+ # A "" host address means INADDR_ANY i.e. binding to all interfaces.
104
+ # Note this is not compatible with IPv6.
105
+ s.bind(("", 0))
106
+ port = s.getsockname()[1]
107
+ return port
108
+
109
+
110
+ _TORCH_DISTRIBUTED_ENV_VARS = (
111
+ "MASTER_ADDR",
112
+ "MASTER_PORT",
113
+ "RANK",
114
+ "WORLD_SIZE",
115
+ "LOCAL_RANK",
116
+ "LOCAL_WORLD_SIZE",
117
+ )
118
+
119
+
120
+ def _collect_env_vars() -> Dict[str, str]:
121
+ return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
122
+
123
+
124
+ def _is_slurm_job_process() -> bool:
125
+ return "SLURM_JOB_ID" in os.environ
126
+
127
+
128
+ def _parse_slurm_node_list(s: str) -> List[str]:
129
+ nodes = []
130
+ # Extract "hostname", "hostname[1-2,3,4-5]," substrings
131
+ p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
132
+ for m in p.finditer(s):
133
+ prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
134
+ for suffix in suffixes.split(","):
135
+ span = suffix.split("-")
136
+ if len(span) == 1:
137
+ nodes.append(prefix + suffix)
138
+ else:
139
+ width = len(span[0])
140
+ start, end = int(span[0]), int(span[1]) + 1
141
+ nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
142
+ return nodes
143
+
144
+
145
+ def _check_env_variable(key: str, new_value: str):
146
+ # Only check for difference with preset environment variables
147
+ if key in os.environ and os.environ[key] != new_value:
148
+ raise RuntimeError(f"Cannot export environment variables as {key} is already set")
149
+
150
+
151
+ class _TorchDistributedEnvironment:
152
+ def __init__(self):
153
+ self.master_addr = "127.0.0.1"
154
+ self.master_port = 0
155
+ self.rank = -1
156
+ self.world_size = -1
157
+ self.local_rank = -1
158
+ self.local_world_size = -1
159
+
160
+ if _is_slurm_job_process():
161
+ return self._set_from_slurm_env()
162
+
163
+ env_vars = _collect_env_vars()
164
+ if not env_vars:
165
+ # Environment is not set
166
+ pass
167
+ elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
168
+ # Environment is fully set
169
+ return self._set_from_preset_env()
170
+ else:
171
+ # Environment is partially set
172
+ collected_env_vars = ", ".join(env_vars.keys())
173
+ raise RuntimeError(f"Partially set environment: {collected_env_vars}")
174
+
175
+ if torch.cuda.device_count() > 0:
176
+ return self._set_from_local()
177
+
178
+ raise RuntimeError("Can't initialize PyTorch distributed environment")
179
+
180
+ # Slurm job created with sbatch, submitit, etc...
181
+ def _set_from_slurm_env(self):
182
+ # logger.info("Initialization from Slurm environment")
183
+ job_id = int(os.environ["SLURM_JOB_ID"])
184
+ node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
185
+ nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
186
+ assert len(nodes) == node_count
187
+
188
+ self.master_addr = nodes[0]
189
+ self.master_port = _get_master_port(seed=job_id)
190
+ self.rank = int(os.environ["SLURM_PROCID"])
191
+ self.world_size = int(os.environ["SLURM_NTASKS"])
192
+ assert self.rank < self.world_size
193
+ self.local_rank = int(os.environ["SLURM_LOCALID"])
194
+ self.local_world_size = self.world_size // node_count
195
+ assert self.local_rank < self.local_world_size
196
+
197
+ # Single node job with preset environment (i.e. torchrun)
198
+ def _set_from_preset_env(self):
199
+ # logger.info("Initialization from preset environment")
200
+ self.master_addr = os.environ["MASTER_ADDR"]
201
+ self.master_port = os.environ["MASTER_PORT"]
202
+ self.rank = int(os.environ["RANK"])
203
+ self.world_size = int(os.environ["WORLD_SIZE"])
204
+ assert self.rank < self.world_size
205
+ self.local_rank = int(os.environ["LOCAL_RANK"])
206
+ self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
207
+ assert self.local_rank < self.local_world_size
208
+
209
+ # Single node and GPU job (i.e. local script run)
210
+ def _set_from_local(self):
211
+ # logger.info("Initialization from local")
212
+ self.master_addr = "127.0.0.1"
213
+ self.master_port = _get_available_port()
214
+ self.rank = 0
215
+ self.world_size = 1
216
+ self.local_rank = 0
217
+ self.local_world_size = 1
218
+
219
+ def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
220
+ # See the "Environment variable initialization" section from
221
+ # https://pytorch.org/docs/stable/distributed.html for the complete list of
222
+ # environment variables required for the env:// initialization method.
223
+ env_vars = {
224
+ "MASTER_ADDR": self.master_addr,
225
+ "MASTER_PORT": str(self.master_port),
226
+ "RANK": str(self.rank),
227
+ "WORLD_SIZE": str(self.world_size),
228
+ "LOCAL_RANK": str(self.local_rank),
229
+ "LOCAL_WORLD_SIZE": str(self.local_world_size),
230
+ }
231
+ if not overwrite:
232
+ for k, v in env_vars.items():
233
+ _check_env_variable(k, v)
234
+
235
+ os.environ.update(env_vars)
236
+ return self
237
+
238
+
239
+ def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
240
+ """Enable distributed mode
241
+
242
+ Args:
243
+ set_cuda_current_device: If True, call torch.cuda.set_device() to set the
244
+ current PyTorch CUDA device to the one matching the local rank.
245
+ overwrite: If True, overwrites already set variables. Else fails.
246
+ """
247
+
248
+ global _LOCAL_RANK, _LOCAL_WORLD_SIZE
249
+ if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
250
+ raise RuntimeError("Distributed mode has already been enabled")
251
+ torch_env = _TorchDistributedEnvironment()
252
+ torch_env.export(overwrite=overwrite)
253
+
254
+ if set_cuda_current_device:
255
+ torch.cuda.set_device(torch_env.local_rank)
256
+
257
+ if allow_nccl_timeout:
258
+ # This allows to use torch distributed timeout in a NCCL backend
259
+ key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
260
+ if not overwrite:
261
+ _check_env_variable(key, value)
262
+ os.environ[key] = value
263
+
264
+ dist.init_process_group(backend="nccl")
265
+ dist.barrier()
266
+
267
+ # Finalize setup
268
+ _LOCAL_RANK = torch_env.local_rank
269
+ _LOCAL_WORLD_SIZE = torch_env.local_world_size
270
+ _restrict_print_to_main_process()
src/dinov2/eval/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
src/dinov2/eval/depth/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
src/dinov2/eval/depth/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .backbones import * # noqa: F403
7
+ from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
8
+ from .decode_heads import * # noqa: F403
9
+ from .depther import * # noqa: F403
10
+ from .losses import * # noqa: F403
src/dinov2/eval/depth/models/backbones/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .vision_transformer import DinoVisionTransformer
src/dinov2/eval/depth/models/backbones/vision_transformer.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from mmcv.runner import BaseModule
7
+
8
+ from ..builder import BACKBONES
9
+
10
+
11
+ @BACKBONES.register_module()
12
+ class DinoVisionTransformer(BaseModule):
13
+ """Vision Transformer."""
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__()
src/dinov2/eval/depth/models/builder.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import warnings
7
+
8
+ from mmcv.cnn import MODELS as MMCV_MODELS
9
+ from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
10
+ from mmcv.utils import Registry
11
+
12
+ MODELS = Registry("models", parent=MMCV_MODELS)
13
+ ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
14
+
15
+
16
+ BACKBONES = MODELS
17
+ NECKS = MODELS
18
+ HEADS = MODELS
19
+ LOSSES = MODELS
20
+ DEPTHER = MODELS
21
+
22
+
23
+ def build_backbone(cfg):
24
+ """Build backbone."""
25
+ return BACKBONES.build(cfg)
26
+
27
+
28
+ def build_neck(cfg):
29
+ """Build neck."""
30
+ return NECKS.build(cfg)
31
+
32
+
33
+ def build_head(cfg):
34
+ """Build head."""
35
+ return HEADS.build(cfg)
36
+
37
+
38
+ def build_loss(cfg):
39
+ """Build loss."""
40
+ return LOSSES.build(cfg)
41
+
42
+
43
+ def build_depther(cfg, train_cfg=None, test_cfg=None):
44
+ """Build depther."""
45
+ if train_cfg is not None or test_cfg is not None:
46
+ warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
47
+ assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
48
+ assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
49
+ return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
src/dinov2/eval/depth/models/decode_heads/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dpt_head import DPTHead
7
+ from .linear_head import BNHead
src/dinov2/eval/depth/models/decode_heads/decode_head.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ from abc import ABCMeta, abstractmethod
8
+
9
+ import mmcv
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from mmcv.runner import BaseModule, auto_fp16, force_fp32
14
+
15
+ from ...ops import resize
16
+ from ..builder import build_loss
17
+
18
+
19
+ class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
20
+ """Base class for BaseDecodeHead.
21
+
22
+ Args:
23
+ in_channels (List): Input channels.
24
+ channels (int): Channels after modules, before conv_depth.
25
+ conv_cfg (dict|None): Config of conv layers. Default: None.
26
+ act_cfg (dict): Config of activation layers.
27
+ Default: dict(type='ReLU')
28
+ loss_decode (dict): Config of decode loss.
29
+ Default: dict(type='SigLoss').
30
+ sampler (dict|None): The config of depth map sampler.
31
+ Default: None.
32
+ align_corners (bool): align_corners argument of F.interpolate.
33
+ Default: False.
34
+ min_depth (int): Min depth in dataset setting.
35
+ Default: 1e-3.
36
+ max_depth (int): Max depth in dataset setting.
37
+ Default: None.
38
+ norm_cfg (dict|None): Config of norm layers.
39
+ Default: None.
40
+ classify (bool): Whether predict depth in a cls.-reg. manner.
41
+ Default: False.
42
+ n_bins (int): The number of bins used in cls. step.
43
+ Default: 256.
44
+ bins_strategy (str): The discrete strategy used in cls. step.
45
+ Default: 'UD'.
46
+ norm_strategy (str): The norm strategy on cls. probability
47
+ distribution. Default: 'linear'
48
+ scale_up (str): Whether predict depth in a scale-up manner.
49
+ Default: False.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ in_channels,
55
+ channels=96,
56
+ conv_cfg=None,
57
+ act_cfg=dict(type="ReLU"),
58
+ loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
59
+ sampler=None,
60
+ align_corners=False,
61
+ min_depth=1e-3,
62
+ max_depth=None,
63
+ norm_cfg=None,
64
+ classify=False,
65
+ n_bins=256,
66
+ bins_strategy="UD",
67
+ norm_strategy="linear",
68
+ scale_up=False,
69
+ ):
70
+ super(DepthBaseDecodeHead, self).__init__()
71
+
72
+ self.in_channels = in_channels
73
+ self.channels = channels
74
+ self.conv_cfg = conv_cfg
75
+ self.act_cfg = act_cfg
76
+ if isinstance(loss_decode, dict):
77
+ self.loss_decode = build_loss(loss_decode)
78
+ elif isinstance(loss_decode, (list, tuple)):
79
+ self.loss_decode = nn.ModuleList()
80
+ for loss in loss_decode:
81
+ self.loss_decode.append(build_loss(loss))
82
+ self.align_corners = align_corners
83
+ self.min_depth = min_depth
84
+ self.max_depth = max_depth
85
+ self.norm_cfg = norm_cfg
86
+ self.classify = classify
87
+ self.n_bins = n_bins
88
+ self.scale_up = scale_up
89
+
90
+ if self.classify:
91
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
92
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
93
+
94
+ self.bins_strategy = bins_strategy
95
+ self.norm_strategy = norm_strategy
96
+ self.softmax = nn.Softmax(dim=1)
97
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
98
+ else:
99
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
100
+
101
+ self.fp16_enabled = False
102
+ self.relu = nn.ReLU()
103
+ self.sigmoid = nn.Sigmoid()
104
+
105
+ def extra_repr(self):
106
+ """Extra repr."""
107
+ s = f"align_corners={self.align_corners}"
108
+ return s
109
+
110
+ @auto_fp16()
111
+ @abstractmethod
112
+ def forward(self, inputs, img_metas):
113
+ """Placeholder of forward function."""
114
+ pass
115
+
116
+ def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
117
+ """Forward function for training.
118
+ Args:
119
+ inputs (list[Tensor]): List of multi-level img features.
120
+ img_metas (list[dict]): List of image info dict where each dict
121
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
122
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
123
+ For details on the values of these keys see
124
+ `depth/datasets/pipelines/formatting.py:Collect`.
125
+ depth_gt (Tensor): GT depth
126
+ train_cfg (dict): The training config.
127
+
128
+ Returns:
129
+ dict[str, Tensor]: a dictionary of loss components
130
+ """
131
+ depth_pred = self.forward(inputs, img_metas)
132
+ losses = self.losses(depth_pred, depth_gt)
133
+
134
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
135
+ losses.update(**log_imgs)
136
+
137
+ return losses
138
+
139
+ def forward_test(self, inputs, img_metas, test_cfg):
140
+ """Forward function for testing.
141
+ Args:
142
+ inputs (list[Tensor]): List of multi-level img features.
143
+ img_metas (list[dict]): List of image info dict where each dict
144
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
145
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
146
+ For details on the values of these keys see
147
+ `depth/datasets/pipelines/formatting.py:Collect`.
148
+ test_cfg (dict): The testing config.
149
+
150
+ Returns:
151
+ Tensor: Output depth map.
152
+ """
153
+ return self.forward(inputs, img_metas)
154
+
155
+ def depth_pred(self, feat):
156
+ """Prediction each pixel."""
157
+ if self.classify:
158
+ logit = self.conv_depth(feat)
159
+
160
+ if self.bins_strategy == "UD":
161
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
162
+ elif self.bins_strategy == "SID":
163
+ bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
164
+
165
+ # following Adabins, default linear
166
+ if self.norm_strategy == "linear":
167
+ logit = torch.relu(logit)
168
+ eps = 0.1
169
+ logit = logit + eps
170
+ logit = logit / logit.sum(dim=1, keepdim=True)
171
+ elif self.norm_strategy == "softmax":
172
+ logit = torch.softmax(logit, dim=1)
173
+ elif self.norm_strategy == "sigmoid":
174
+ logit = torch.sigmoid(logit)
175
+ logit = logit / logit.sum(dim=1, keepdim=True)
176
+
177
+ output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
178
+
179
+ else:
180
+ if self.scale_up:
181
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
182
+ else:
183
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
184
+ return output
185
+
186
+ @force_fp32(apply_to=("depth_pred",))
187
+ def losses(self, depth_pred, depth_gt):
188
+ """Compute depth loss."""
189
+ loss = dict()
190
+ depth_pred = resize(
191
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
192
+ )
193
+ if not isinstance(self.loss_decode, nn.ModuleList):
194
+ losses_decode = [self.loss_decode]
195
+ else:
196
+ losses_decode = self.loss_decode
197
+ for loss_decode in losses_decode:
198
+ if loss_decode.loss_name not in loss:
199
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
200
+ else:
201
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
202
+ return loss
203
+
204
+ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
205
+ show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
206
+ show_img = show_img.numpy().astype(np.float32)
207
+ show_img = mmcv.imdenormalize(
208
+ show_img,
209
+ img_meta["img_norm_cfg"]["mean"],
210
+ img_meta["img_norm_cfg"]["std"],
211
+ img_meta["img_norm_cfg"]["to_rgb"],
212
+ )
213
+ show_img = np.clip(show_img, 0, 255)
214
+ show_img = show_img.astype(np.uint8)
215
+ show_img = show_img[:, :, ::-1]
216
+ show_img = show_img.transpose(0, 2, 1)
217
+ show_img = show_img.transpose(1, 0, 2)
218
+
219
+ depth_pred = depth_pred / torch.max(depth_pred)
220
+ depth_gt = depth_gt / torch.max(depth_gt)
221
+
222
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
223
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
224
+
225
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
src/dinov2/eval/depth/models/decode_heads/dpt_head.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from mmcv.cnn import ConvModule, Linear, build_activation_layer
11
+ from mmcv.runner import BaseModule
12
+
13
+ from ...ops import resize
14
+ from ..builder import HEADS
15
+ from .decode_head import DepthBaseDecodeHead
16
+
17
+
18
+ class Interpolate(nn.Module):
19
+ def __init__(self, scale_factor, mode, align_corners=False):
20
+ super(Interpolate, self).__init__()
21
+ self.interp = nn.functional.interpolate
22
+ self.scale_factor = scale_factor
23
+ self.mode = mode
24
+ self.align_corners = align_corners
25
+
26
+ def forward(self, x):
27
+ x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
28
+ return x
29
+
30
+
31
+ class HeadDepth(nn.Module):
32
+ def __init__(self, features):
33
+ super(HeadDepth, self).__init__()
34
+ self.head = nn.Sequential(
35
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
36
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
37
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
38
+ nn.ReLU(),
39
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
40
+ )
41
+
42
+ def forward(self, x):
43
+ x = self.head(x)
44
+ return x
45
+
46
+
47
+ class ReassembleBlocks(BaseModule):
48
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
49
+ rearrange the feature vector to feature map.
50
+ Args:
51
+ in_channels (int): ViT feature channels. Default: 768.
52
+ out_channels (List): output channels of each stage.
53
+ Default: [96, 192, 384, 768].
54
+ readout_type (str): Type of readout operation. Default: 'ignore'.
55
+ patch_size (int): The patch size. Default: 16.
56
+ init_cfg (dict, optional): Initialization config dict. Default: None.
57
+ """
58
+
59
+ def __init__(
60
+ self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
61
+ ):
62
+ super(ReassembleBlocks, self).__init__(init_cfg)
63
+
64
+ assert readout_type in ["ignore", "add", "project"]
65
+ self.readout_type = readout_type
66
+ self.patch_size = patch_size
67
+
68
+ self.projects = nn.ModuleList(
69
+ [
70
+ ConvModule(
71
+ in_channels=in_channels,
72
+ out_channels=out_channel,
73
+ kernel_size=1,
74
+ act_cfg=None,
75
+ )
76
+ for out_channel in out_channels
77
+ ]
78
+ )
79
+
80
+ self.resize_layers = nn.ModuleList(
81
+ [
82
+ nn.ConvTranspose2d(
83
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
84
+ ),
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
87
+ ),
88
+ nn.Identity(),
89
+ nn.Conv2d(
90
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
91
+ ),
92
+ ]
93
+ )
94
+ if self.readout_type == "project":
95
+ self.readout_projects = nn.ModuleList()
96
+ for _ in range(len(self.projects)):
97
+ self.readout_projects.append(
98
+ nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
99
+ )
100
+
101
+ def forward(self, inputs):
102
+ assert isinstance(inputs, list)
103
+ out = []
104
+ for i, x in enumerate(inputs):
105
+ assert len(x) == 2
106
+ x, cls_token = x[0], x[1]
107
+ feature_shape = x.shape
108
+ if self.readout_type == "project":
109
+ x = x.flatten(2).permute((0, 2, 1))
110
+ readout = cls_token.unsqueeze(1).expand_as(x)
111
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
112
+ x = x.permute(0, 2, 1).reshape(feature_shape)
113
+ elif self.readout_type == "add":
114
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
115
+ x = x.reshape(feature_shape)
116
+ else:
117
+ pass
118
+ x = self.projects[i](x)
119
+ x = self.resize_layers[i](x)
120
+ out.append(x)
121
+ return out
122
+
123
+
124
+ class PreActResidualConvUnit(BaseModule):
125
+ """ResidualConvUnit, pre-activate residual unit.
126
+ Args:
127
+ in_channels (int): number of channels in the input feature map.
128
+ act_cfg (dict): dictionary to construct and config activation layer.
129
+ norm_cfg (dict): dictionary to construct and config norm layer.
130
+ stride (int): stride of the first block. Default: 1
131
+ dilation (int): dilation rate for convs layers. Default: 1.
132
+ init_cfg (dict, optional): Initialization config dict. Default: None.
133
+ """
134
+
135
+ def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
136
+ super(PreActResidualConvUnit, self).__init__(init_cfg)
137
+
138
+ self.conv1 = ConvModule(
139
+ in_channels,
140
+ in_channels,
141
+ 3,
142
+ stride=stride,
143
+ padding=dilation,
144
+ dilation=dilation,
145
+ norm_cfg=norm_cfg,
146
+ act_cfg=act_cfg,
147
+ bias=False,
148
+ order=("act", "conv", "norm"),
149
+ )
150
+
151
+ self.conv2 = ConvModule(
152
+ in_channels,
153
+ in_channels,
154
+ 3,
155
+ padding=1,
156
+ norm_cfg=norm_cfg,
157
+ act_cfg=act_cfg,
158
+ bias=False,
159
+ order=("act", "conv", "norm"),
160
+ )
161
+
162
+ def forward(self, inputs):
163
+ inputs_ = inputs.clone()
164
+ x = self.conv1(inputs)
165
+ x = self.conv2(x)
166
+ return x + inputs_
167
+
168
+
169
+ class FeatureFusionBlock(BaseModule):
170
+ """FeatureFusionBlock, merge feature map from different stages.
171
+ Args:
172
+ in_channels (int): Input channels.
173
+ act_cfg (dict): The activation config for ResidualConvUnit.
174
+ norm_cfg (dict): Config dict for normalization layer.
175
+ expand (bool): Whether expand the channels in post process block.
176
+ Default: False.
177
+ align_corners (bool): align_corner setting for bilinear upsample.
178
+ Default: True.
179
+ init_cfg (dict, optional): Initialization config dict. Default: None.
180
+ """
181
+
182
+ def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
183
+ super(FeatureFusionBlock, self).__init__(init_cfg)
184
+
185
+ self.in_channels = in_channels
186
+ self.expand = expand
187
+ self.align_corners = align_corners
188
+
189
+ self.out_channels = in_channels
190
+ if self.expand:
191
+ self.out_channels = in_channels // 2
192
+
193
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)
194
+
195
+ self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
196
+ self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
197
+
198
+ def forward(self, *inputs):
199
+ x = inputs[0]
200
+ if len(inputs) == 2:
201
+ if x.shape != inputs[1].shape:
202
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
203
+ else:
204
+ res = inputs[1]
205
+ x = x + self.res_conv_unit1(res)
206
+ x = self.res_conv_unit2(x)
207
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
208
+ x = self.project(x)
209
+ return x
210
+
211
+
212
+ @HEADS.register_module()
213
+ class DPTHead(DepthBaseDecodeHead):
214
+ """Vision Transformers for Dense Prediction.
215
+ This head is implemented of `DPT <https://arxiv.org/abs/2103.13413>`_.
216
+ Args:
217
+ embed_dims (int): The embed dimension of the ViT backbone.
218
+ Default: 768.
219
+ post_process_channels (List): Out channels of post process conv
220
+ layers. Default: [96, 192, 384, 768].
221
+ readout_type (str): Type of readout operation. Default: 'ignore'.
222
+ patch_size (int): The patch size. Default: 16.
223
+ expand_channels (bool): Whether expand the channels in post process
224
+ block. Default: False.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ embed_dims=768,
230
+ post_process_channels=[96, 192, 384, 768],
231
+ readout_type="ignore",
232
+ patch_size=16,
233
+ expand_channels=False,
234
+ **kwargs
235
+ ):
236
+ super(DPTHead, self).__init__(**kwargs)
237
+
238
+ self.in_channels = self.in_channels
239
+ self.expand_channels = expand_channels
240
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
241
+
242
+ self.post_process_channels = [
243
+ channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
244
+ ]
245
+ self.convs = nn.ModuleList()
246
+ for channel in self.post_process_channels:
247
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
248
+ self.fusion_blocks = nn.ModuleList()
249
+ for _ in range(len(self.convs)):
250
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
251
+ self.fusion_blocks[0].res_conv_unit1 = None
252
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
253
+ self.num_fusion_blocks = len(self.fusion_blocks)
254
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
255
+ self.num_post_process_channels = len(self.post_process_channels)
256
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
257
+ assert self.num_reassemble_blocks == self.num_post_process_channels
258
+ self.conv_depth = HeadDepth(self.channels)
259
+
260
+ def forward(self, inputs, img_metas):
261
+ assert len(inputs) == self.num_reassemble_blocks
262
+ x = [inp for inp in inputs]
263
+ x = self.reassemble_blocks(x)
264
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
265
+ out = self.fusion_blocks[0](x[-1])
266
+ for i in range(1, len(self.fusion_blocks)):
267
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
268
+ out = self.project(out)
269
+ out = self.depth_pred(out)
270
+ return out
src/dinov2/eval/depth/models/decode_heads/linear_head.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ...ops import resize
10
+ from ..builder import HEADS
11
+ from .decode_head import DepthBaseDecodeHead
12
+
13
+
14
+ @HEADS.register_module()
15
+ class BNHead(DepthBaseDecodeHead):
16
+ """Just a batchnorm."""
17
+
18
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.input_transform = input_transform
21
+ self.in_index = in_index
22
+ self.upsample = upsample
23
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
24
+ if self.classify:
25
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
26
+ else:
27
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
28
+
29
+ def _transform_inputs(self, inputs):
30
+ """Transform inputs for decoder.
31
+ Args:
32
+ inputs (list[Tensor]): List of multi-level img features.
33
+ Returns:
34
+ Tensor: The transformed inputs
35
+ """
36
+
37
+ if "concat" in self.input_transform:
38
+ inputs = [inputs[i] for i in self.in_index]
39
+ if "resize" in self.input_transform:
40
+ inputs = [
41
+ resize(
42
+ input=x,
43
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
44
+ mode="bilinear",
45
+ align_corners=self.align_corners,
46
+ )
47
+ for x in inputs
48
+ ]
49
+ inputs = torch.cat(inputs, dim=1)
50
+ elif self.input_transform == "multiple_select":
51
+ inputs = [inputs[i] for i in self.in_index]
52
+ else:
53
+ inputs = inputs[self.in_index]
54
+
55
+ return inputs
56
+
57
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
58
+ """Forward function for feature maps before classifying each pixel with
59
+ ``self.cls_seg`` fc.
60
+ Args:
61
+ inputs (list[Tensor]): List of multi-level img features.
62
+ Returns:
63
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
64
+ H, W) which is feature map for last layer of decoder head.
65
+ """
66
+ # accept lists (for cls token)
67
+ inputs = list(inputs)
68
+ for i, x in enumerate(inputs):
69
+ if len(x) == 2:
70
+ x, cls_token = x[0], x[1]
71
+ if len(x.shape) == 2:
72
+ x = x[:, :, None, None]
73
+ cls_token = cls_token[:, :, None, None].expand_as(x)
74
+ inputs[i] = torch.cat((x, cls_token), 1)
75
+ else:
76
+ x = x[0]
77
+ if len(x.shape) == 2:
78
+ x = x[:, :, None, None]
79
+ inputs[i] = x
80
+ x = self._transform_inputs(inputs)
81
+ # feats = self.bn(x)
82
+ return x
83
+
84
+ def forward(self, inputs, img_metas=None, **kwargs):
85
+ """Forward function."""
86
+ output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
87
+ output = self.depth_pred(output)
88
+
89
+ return output
src/dinov2/eval/depth/models/depther/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .base import BaseDepther
7
+ from .encoder_decoder import DepthEncoderDecoder
src/dinov2/eval/depth/models/depther/base.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import ABCMeta, abstractmethod
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ from mmcv.runner import BaseModule, auto_fp16
12
+
13
+
14
+ class BaseDepther(BaseModule, metaclass=ABCMeta):
15
+ """Base class for depther."""
16
+
17
+ def __init__(self, init_cfg=None):
18
+ super(BaseDepther, self).__init__(init_cfg)
19
+ self.fp16_enabled = False
20
+
21
+ @property
22
+ def with_neck(self):
23
+ """bool: whether the depther has neck"""
24
+ return hasattr(self, "neck") and self.neck is not None
25
+
26
+ @property
27
+ def with_auxiliary_head(self):
28
+ """bool: whether the depther has auxiliary head"""
29
+ return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None
30
+
31
+ @property
32
+ def with_decode_head(self):
33
+ """bool: whether the depther has decode head"""
34
+ return hasattr(self, "decode_head") and self.decode_head is not None
35
+
36
+ @abstractmethod
37
+ def extract_feat(self, imgs):
38
+ """Placeholder for extract features from images."""
39
+ pass
40
+
41
+ @abstractmethod
42
+ def encode_decode(self, img, img_metas):
43
+ """Placeholder for encode images with backbone and decode into a
44
+ semantic depth map of the same size as input."""
45
+ pass
46
+
47
+ @abstractmethod
48
+ def forward_train(self, imgs, img_metas, **kwargs):
49
+ """Placeholder for Forward function for training."""
50
+ pass
51
+
52
+ @abstractmethod
53
+ def simple_test(self, img, img_meta, **kwargs):
54
+ """Placeholder for single image test."""
55
+ pass
56
+
57
+ @abstractmethod
58
+ def aug_test(self, imgs, img_metas, **kwargs):
59
+ """Placeholder for augmentation test."""
60
+ pass
61
+
62
+ def forward_test(self, imgs, img_metas, **kwargs):
63
+ """
64
+ Args:
65
+ imgs (List[Tensor]): the outer list indicates test-time
66
+ augmentations and inner Tensor should have a shape NxCxHxW,
67
+ which contains all images in the batch.
68
+ img_metas (List[List[dict]]): the outer list indicates test-time
69
+ augs (multiscale, flip, etc.) and the inner list indicates
70
+ images in a batch.
71
+ """
72
+ for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
73
+ if not isinstance(var, list):
74
+ raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
75
+ num_augs = len(imgs)
76
+ if num_augs != len(img_metas):
77
+ raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
78
+ # all images in the same aug batch all of the same ori_shape and pad
79
+ # shape
80
+ for img_meta in img_metas:
81
+ ori_shapes = [_["ori_shape"] for _ in img_meta]
82
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
83
+ img_shapes = [_["img_shape"] for _ in img_meta]
84
+ assert all(shape == img_shapes[0] for shape in img_shapes)
85
+ pad_shapes = [_["pad_shape"] for _ in img_meta]
86
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
87
+
88
+ if num_augs == 1:
89
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
90
+ else:
91
+ return self.aug_test(imgs, img_metas, **kwargs)
92
+
93
+ @auto_fp16(apply_to=("img",))
94
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
95
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
96
+ on whether ``return_loss`` is ``True``.
97
+
98
+ Note this setting will change the expected inputs. When
99
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
100
+ and List[dict]), and when ``resturn_loss=False``, img and img_meta
101
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
102
+ the outer list indicating test time augmentations.
103
+ """
104
+ if return_loss:
105
+ return self.forward_train(img, img_metas, **kwargs)
106
+ else:
107
+ return self.forward_test(img, img_metas, **kwargs)
108
+
109
+ def train_step(self, data_batch, optimizer, **kwargs):
110
+ """The iteration step during training.
111
+
112
+ This method defines an iteration step during training, except for the
113
+ back propagation and optimizer updating, which are done in an optimizer
114
+ hook. Note that in some complicated cases or models, the whole process
115
+ including back propagation and optimizer updating is also defined in
116
+ this method, such as GAN.
117
+
118
+ Args:
119
+ data (dict): The output of dataloader.
120
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
121
+ runner is passed to ``train_step()``. This argument is unused
122
+ and reserved.
123
+
124
+ Returns:
125
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
126
+ ``num_samples``.
127
+ ``loss`` is a tensor for back propagation, which can be a
128
+ weighted sum of multiple losses.
129
+ ``log_vars`` contains all the variables to be sent to the
130
+ logger.
131
+ ``num_samples`` indicates the batch size (when the model is
132
+ DDP, it means the batch size on each GPU), which is used for
133
+ averaging the logs.
134
+ """
135
+ losses = self(**data_batch)
136
+
137
+ # split losses and images
138
+ real_losses = {}
139
+ log_imgs = {}
140
+ for k, v in losses.items():
141
+ if "img" in k:
142
+ log_imgs[k] = v
143
+ else:
144
+ real_losses[k] = v
145
+
146
+ loss, log_vars = self._parse_losses(real_losses)
147
+
148
+ outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
149
+
150
+ return outputs
151
+
152
+ def val_step(self, data_batch, **kwargs):
153
+ """The iteration step during validation.
154
+
155
+ This method shares the same signature as :func:`train_step`, but used
156
+ during val epochs. Note that the evaluation after training epochs is
157
+ not implemented with this method, but an evaluation hook.
158
+ """
159
+ output = self(**data_batch, **kwargs)
160
+ return output
161
+
162
+ @staticmethod
163
+ def _parse_losses(losses):
164
+ """Parse the raw outputs (losses) of the network.
165
+
166
+ Args:
167
+ losses (dict): Raw output of the network, which usually contain
168
+ losses and other necessary information.
169
+
170
+ Returns:
171
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
172
+ which may be a weighted sum of all losses, log_vars contains
173
+ all the variables to be sent to the logger.
174
+ """
175
+ log_vars = OrderedDict()
176
+ for loss_name, loss_value in losses.items():
177
+ if isinstance(loss_value, torch.Tensor):
178
+ log_vars[loss_name] = loss_value.mean()
179
+ elif isinstance(loss_value, list):
180
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
181
+ else:
182
+ raise TypeError(f"{loss_name} is not a tensor or list of tensors")
183
+
184
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
185
+
186
+ log_vars["loss"] = loss
187
+ for loss_name, loss_value in log_vars.items():
188
+ # reduce loss when distributed training
189
+ if dist.is_available() and dist.is_initialized():
190
+ loss_value = loss_value.data.clone()
191
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
192
+ log_vars[loss_name] = loss_value.item()
193
+
194
+ return loss, log_vars
src/dinov2/eval/depth/models/depther/encoder_decoder.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from ...models import builder
10
+ from ...models.builder import DEPTHER
11
+ from ...ops import resize
12
+ from .base import BaseDepther
13
+
14
+
15
+ def add_prefix(inputs, prefix):
16
+ """Add prefix for dict.
17
+
18
+ Args:
19
+ inputs (dict): The input dict with str keys.
20
+ prefix (str): The prefix to add.
21
+
22
+ Returns:
23
+
24
+ dict: The dict with keys updated with ``prefix``.
25
+ """
26
+
27
+ outputs = dict()
28
+ for name, value in inputs.items():
29
+ outputs[f"{prefix}.{name}"] = value
30
+
31
+ return outputs
32
+
33
+
34
+ @DEPTHER.register_module()
35
+ class DepthEncoderDecoder(BaseDepther):
36
+ """Encoder Decoder depther.
37
+
38
+ EncoderDecoder typically consists of backbone, (neck) and decode_head.
39
+ """
40
+
41
+ def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
42
+ super(DepthEncoderDecoder, self).__init__(init_cfg)
43
+ if pretrained is not None:
44
+ assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
45
+ backbone.pretrained = pretrained
46
+ self.backbone = builder.build_backbone(backbone)
47
+ self._init_decode_head(decode_head)
48
+
49
+ if neck is not None:
50
+ self.neck = builder.build_neck(neck)
51
+
52
+ self.train_cfg = train_cfg
53
+ self.test_cfg = test_cfg
54
+
55
+ assert self.with_decode_head
56
+
57
+ def _init_decode_head(self, decode_head):
58
+ """Initialize ``decode_head``"""
59
+ self.decode_head = builder.build_head(decode_head)
60
+ self.align_corners = self.decode_head.align_corners
61
+
62
+ def extract_feat(self, img):
63
+ """Extract features from images."""
64
+ x = self.backbone(img)
65
+ if self.with_neck:
66
+ x = self.neck(x)
67
+ return x
68
+
69
+ def encode_decode(self, img, img_metas, rescale=True, size=None):
70
+ """Encode images with backbone and decode into a depth estimation
71
+ map of the same size as input."""
72
+ x = self.extract_feat(img)
73
+ out = self._decode_head_forward_test(x, img_metas)
74
+ # crop the pred depth to the certain range.
75
+ out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
76
+ if rescale:
77
+ if size is None:
78
+ if img_metas is not None:
79
+ size = img_metas[0]["ori_shape"][:2]
80
+ else:
81
+ size = img.shape[2:]
82
+ out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
83
+ return out
84
+
85
+ def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
86
+ """Run forward function and calculate loss for decode head in
87
+ training."""
88
+ losses = dict()
89
+ loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
90
+ losses.update(add_prefix(loss_decode, "decode"))
91
+ return losses
92
+
93
+ def _decode_head_forward_test(self, x, img_metas):
94
+ """Run forward function and calculate loss for decode head in
95
+ inference."""
96
+ depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
97
+ return depth_pred
98
+
99
+ def forward_dummy(self, img):
100
+ """Dummy forward function."""
101
+ depth = self.encode_decode(img, None)
102
+
103
+ return depth
104
+
105
+ def forward_train(self, img, img_metas, depth_gt, **kwargs):
106
+ """Forward function for training.
107
+
108
+ Args:
109
+ img (Tensor): Input images.
110
+ img_metas (list[dict]): List of image info dict where each dict
111
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
112
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
113
+ For details on the values of these keys see
114
+ `depth/datasets/pipelines/formatting.py:Collect`.
115
+ depth_gt (Tensor): Depth gt
116
+ used if the architecture supports depth estimation task.
117
+
118
+ Returns:
119
+ dict[str, Tensor]: a dictionary of loss components
120
+ """
121
+
122
+ x = self.extract_feat(img)
123
+
124
+ losses = dict()
125
+
126
+ # the last of x saves the info from neck
127
+ loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
128
+
129
+ losses.update(loss_decode)
130
+
131
+ return losses
132
+
133
+ def whole_inference(self, img, img_meta, rescale, size=None):
134
+ """Inference with full image."""
135
+ depth_pred = self.encode_decode(img, img_meta, rescale, size=size)
136
+
137
+ return depth_pred
138
+
139
+ def slide_inference(self, img, img_meta, rescale):
140
+ """Inference by sliding-window with overlap.
141
+
142
+ If h_crop > h_img or w_crop > w_img, the small patch will be used to
143
+ decode without padding.
144
+ """
145
+
146
+ h_stride, w_stride = self.test_cfg.stride
147
+ h_crop, w_crop = self.test_cfg.crop_size
148
+ batch_size, _, h_img, w_img = img.size()
149
+ h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
150
+ w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
151
+ preds = img.new_zeros((batch_size, 1, h_img, w_img))
152
+ count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
153
+ for h_idx in range(h_grids):
154
+ for w_idx in range(w_grids):
155
+ y1 = h_idx * h_stride
156
+ x1 = w_idx * w_stride
157
+ y2 = min(y1 + h_crop, h_img)
158
+ x2 = min(x1 + w_crop, w_img)
159
+ y1 = max(y2 - h_crop, 0)
160
+ x1 = max(x2 - w_crop, 0)
161
+ crop_img = img[:, :, y1:y2, x1:x2]
162
+ depth_pred = self.encode_decode(crop_img, img_meta, rescale)
163
+ preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
164
+
165
+ count_mat[:, :, y1:y2, x1:x2] += 1
166
+ assert (count_mat == 0).sum() == 0
167
+ if torch.onnx.is_in_onnx_export():
168
+ # cast count_mat to constant while exporting to ONNX
169
+ count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
170
+ preds = preds / count_mat
171
+ return preds
172
+
173
+ def inference(self, img, img_meta, rescale, size=None):
174
+ """Inference with slide/whole style.
175
+
176
+ Args:
177
+ img (Tensor): The input image of shape (N, 3, H, W).
178
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
179
+ 'scale_factor', 'flip', and may also contain
180
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
181
+ For details on the values of these keys see
182
+ `depth/datasets/pipelines/formatting.py:Collect`.
183
+ rescale (bool): Whether rescale back to original shape.
184
+
185
+ Returns:
186
+ Tensor: The output depth map.
187
+ """
188
+
189
+ assert self.test_cfg.mode in ["slide", "whole"]
190
+ ori_shape = img_meta[0]["ori_shape"]
191
+ assert all(_["ori_shape"] == ori_shape for _ in img_meta)
192
+ if self.test_cfg.mode == "slide":
193
+ depth_pred = self.slide_inference(img, img_meta, rescale)
194
+ else:
195
+ depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
196
+ output = depth_pred
197
+ flip = img_meta[0]["flip"]
198
+ if flip:
199
+ flip_direction = img_meta[0]["flip_direction"]
200
+ assert flip_direction in ["horizontal", "vertical"]
201
+ if flip_direction == "horizontal":
202
+ output = output.flip(dims=(3,))
203
+ elif flip_direction == "vertical":
204
+ output = output.flip(dims=(2,))
205
+
206
+ return output
207
+
208
+ def simple_test(self, img, img_meta, rescale=True):
209
+ """Simple test with single image."""
210
+ depth_pred = self.inference(img, img_meta, rescale)
211
+ if torch.onnx.is_in_onnx_export():
212
+ # our inference backend only support 4D output
213
+ depth_pred = depth_pred.unsqueeze(0)
214
+ return depth_pred
215
+ depth_pred = depth_pred.cpu().numpy()
216
+ # unravel batch dim
217
+ depth_pred = list(depth_pred)
218
+ return depth_pred
219
+
220
+ def aug_test(self, imgs, img_metas, rescale=True):
221
+ """Test with augmentations.
222
+
223
+ Only rescale=True is supported.
224
+ """
225
+ # aug_test rescale all imgs back to ori_shape for now
226
+ assert rescale
227
+ # to save memory, we get augmented depth logit inplace
228
+ depth_pred = self.inference(imgs[0], img_metas[0], rescale)
229
+ for i in range(1, len(imgs)):
230
+ cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
231
+ depth_pred += cur_depth_pred
232
+ depth_pred /= len(imgs)
233
+ depth_pred = depth_pred.cpu().numpy()
234
+ # unravel batch dim
235
+ depth_pred = list(depth_pred)
236
+ return depth_pred
src/dinov2/eval/depth/models/losses/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .gradientloss import GradientLoss
7
+ from .sigloss import SigLoss
src/dinov2/eval/depth/models/losses/gradientloss.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ...models.builder import LOSSES
10
+
11
+
12
+ @LOSSES.register_module()
13
+ class GradientLoss(nn.Module):
14
+ """GradientLoss.
15
+
16
+ Adapted from https://www.cs.cornell.edu/projects/megadepth/
17
+
18
+ Args:
19
+ valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
20
+ loss_weight (float): Weight of the loss. Default: 1.0.
21
+ max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
22
+ """
23
+
24
+ def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
25
+ super(GradientLoss, self).__init__()
26
+ self.valid_mask = valid_mask
27
+ self.loss_weight = loss_weight
28
+ self.max_depth = max_depth
29
+ self.loss_name = loss_name
30
+
31
+ self.eps = 0.001 # avoid grad explode
32
+
33
+ def gradientloss(self, input, target):
34
+ input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)]
35
+ target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)]
36
+
37
+ gradient_loss = 0
38
+ for input, target in zip(input_downscaled, target_downscaled):
39
+ if self.valid_mask:
40
+ mask = target > 0
41
+ if self.max_depth is not None:
42
+ mask = torch.logical_and(target > 0, target <= self.max_depth)
43
+ N = torch.sum(mask)
44
+ else:
45
+ mask = torch.ones_like(target)
46
+ N = input.numel()
47
+ input_log = torch.log(input + self.eps)
48
+ target_log = torch.log(target + self.eps)
49
+ log_d_diff = input_log - target_log
50
+
51
+ log_d_diff = torch.mul(log_d_diff, mask)
52
+
53
+ v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
54
+ v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
55
+ v_gradient = torch.mul(v_gradient, v_mask)
56
+
57
+ h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
58
+ h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
59
+ h_gradient = torch.mul(h_gradient, h_mask)
60
+
61
+ gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N
62
+
63
+ return gradient_loss
64
+
65
+ def forward(self, depth_pred, depth_gt):
66
+ """Forward function."""
67
+
68
+ gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
69
+ return gradient_loss
src/dinov2/eval/depth/models/losses/sigloss.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ...models.builder import LOSSES
10
+
11
+
12
+ @LOSSES.register_module()
13
+ class SigLoss(nn.Module):
14
+ """SigLoss.
15
+
16
+ This follows `AdaBins <https://arxiv.org/abs/2011.14141>`_.
17
+
18
+ Args:
19
+ valid_mask (bool): Whether filter invalid gt (gt > 0). Default: True.
20
+ loss_weight (float): Weight of the loss. Default: 1.0.
21
+ max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
22
+ warm_up (bool): A simple warm up stage to help convergence. Default: False.
23
+ warm_iter (int): The number of warm up stage. Default: 100.
24
+ """
25
+
26
+ def __init__(
27
+ self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
28
+ ):
29
+ super(SigLoss, self).__init__()
30
+ self.valid_mask = valid_mask
31
+ self.loss_weight = loss_weight
32
+ self.max_depth = max_depth
33
+ self.loss_name = loss_name
34
+
35
+ self.eps = 0.001 # avoid grad explode
36
+
37
+ # HACK: a hack implementation for warmup sigloss
38
+ self.warm_up = warm_up
39
+ self.warm_iter = warm_iter
40
+ self.warm_up_counter = 0
41
+
42
+ def sigloss(self, input, target):
43
+ if self.valid_mask:
44
+ valid_mask = target > 0
45
+ if self.max_depth is not None:
46
+ valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
47
+ input = input[valid_mask]
48
+ target = target[valid_mask]
49
+
50
+ if self.warm_up:
51
+ if self.warm_up_counter < self.warm_iter:
52
+ g = torch.log(input + self.eps) - torch.log(target + self.eps)
53
+ g = 0.15 * torch.pow(torch.mean(g), 2)
54
+ self.warm_up_counter += 1
55
+ return torch.sqrt(g)
56
+
57
+ g = torch.log(input + self.eps) - torch.log(target + self.eps)
58
+ Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
59
+ return torch.sqrt(Dg)
60
+
61
+ def forward(self, depth_pred, depth_gt):
62
+ """Forward function."""
63
+
64
+ loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt)
65
+ return loss_depth
src/dinov2/eval/depth/ops/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .wrappers import resize
src/dinov2/eval/depth/ops/wrappers.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import warnings
7
+
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
12
+ if warning:
13
+ if size is not None and align_corners:
14
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
15
+ output_h, output_w = tuple(int(x) for x in size)
16
+ if output_h > input_h or output_w > output_h:
17
+ if (
18
+ (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
19
+ and (output_h - 1) % (input_h - 1)
20
+ and (output_w - 1) % (input_w - 1)
21
+ ):
22
+ warnings.warn(
23
+ f"When align_corners={align_corners}, "
24
+ "the output would more aligned if "
25
+ f"input size {(input_h, input_w)} is `x+1` and "
26
+ f"out size {(output_h, output_w)} is `nx+1`"
27
+ )
28
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
src/dinov2/eval/knn.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ from functools import partial
8
+ import json
9
+ import logging
10
+ import os
11
+ import sys
12
+ from typing import List, Optional
13
+
14
+ import torch
15
+ from torch.nn.functional import one_hot, softmax
16
+
17
+ import dinov2.distributed as distributed
18
+ from dinov2.data import SamplerType, make_data_loader, make_dataset
19
+ from dinov2.data.transforms import make_classification_eval_transform
20
+ from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
21
+ from dinov2.eval.setup import get_args_parser as get_setup_args_parser
22
+ from dinov2.eval.setup import setup_and_build_model
23
+ from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
24
+
25
+
26
+ logger = logging.getLogger("dinov2")
27
+
28
+
29
+ def get_args_parser(
30
+ description: Optional[str] = None,
31
+ parents: Optional[List[argparse.ArgumentParser]] = None,
32
+ add_help: bool = True,
33
+ ):
34
+ parents = parents or []
35
+ setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
36
+ parents = [setup_args_parser]
37
+ parser = argparse.ArgumentParser(
38
+ description=description,
39
+ parents=parents,
40
+ add_help=add_help,
41
+ )
42
+ parser.add_argument(
43
+ "--train-dataset",
44
+ dest="train_dataset_str",
45
+ type=str,
46
+ help="Training dataset",
47
+ )
48
+ parser.add_argument(
49
+ "--val-dataset",
50
+ dest="val_dataset_str",
51
+ type=str,
52
+ help="Validation dataset",
53
+ )
54
+ parser.add_argument(
55
+ "--nb_knn",
56
+ nargs="+",
57
+ type=int,
58
+ help="Number of NN to use. 20 is usually working the best.",
59
+ )
60
+ parser.add_argument(
61
+ "--temperature",
62
+ type=float,
63
+ help="Temperature used in the voting coefficient",
64
+ )
65
+ parser.add_argument(
66
+ "--gather-on-cpu",
67
+ action="store_true",
68
+ help="Whether to gather the train features on cpu, slower"
69
+ "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
70
+ )
71
+ parser.add_argument(
72
+ "--batch-size",
73
+ type=int,
74
+ help="Batch size.",
75
+ )
76
+ parser.add_argument(
77
+ "--n-per-class-list",
78
+ nargs="+",
79
+ type=int,
80
+ help="Number to take per class",
81
+ )
82
+ parser.add_argument(
83
+ "--n-tries",
84
+ type=int,
85
+ help="Number of tries",
86
+ )
87
+ parser.set_defaults(
88
+ train_dataset_str="ImageNet:split=TRAIN",
89
+ val_dataset_str="ImageNet:split=VAL",
90
+ nb_knn=[10, 20, 100, 200],
91
+ temperature=0.07,
92
+ batch_size=256,
93
+ n_per_class_list=[-1],
94
+ n_tries=1,
95
+ )
96
+ return parser
97
+
98
+
99
+ class KnnModule(torch.nn.Module):
100
+ """
101
+ Gets knn of test features from all processes on a chunk of the train features
102
+
103
+ Each rank gets a chunk of the train features as well as a chunk of the test features.
104
+ In `compute_neighbors`, for each rank one after the other, its chunk of test features
105
+ is sent to all devices, partial knns are computed with each chunk of train features
106
+ then collated back on the original device.
107
+ """
108
+
109
+ def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
110
+ super().__init__()
111
+
112
+ self.global_rank = distributed.get_global_rank()
113
+ self.global_size = distributed.get_global_size()
114
+
115
+ self.device = device
116
+ self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device)
117
+ self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device)
118
+
119
+ self.nb_knn = nb_knn
120
+ self.max_k = max(self.nb_knn)
121
+ self.T = T
122
+ self.num_classes = num_classes
123
+
124
+ def _get_knn_sims_and_labels(self, similarity, train_labels):
125
+ topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True)
126
+ neighbors_labels = torch.gather(train_labels, 1, indices)
127
+ return topk_sims, neighbors_labels
128
+
129
+ def _similarity_for_rank(self, features_rank, source_rank):
130
+ # Send the features from `source_rank` to all ranks
131
+ broadcast_shape = torch.tensor(features_rank.shape).to(self.device)
132
+ torch.distributed.broadcast(broadcast_shape, source_rank)
133
+
134
+ broadcasted = features_rank
135
+ if self.global_rank != source_rank:
136
+ broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device)
137
+ torch.distributed.broadcast(broadcasted, source_rank)
138
+
139
+ # Compute the neighbors for `source_rank` among `train_features_rank_T`
140
+ similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
141
+ candidate_labels = self.candidates.expand(len(similarity_rank), -1)
142
+ return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)
143
+
144
+ def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
145
+ # Gather all neighbors for `target_rank`
146
+ topk_sims_rank = retrieved_rank = None
147
+ if self.global_rank == target_rank:
148
+ topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
149
+ retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]
150
+
151
+ torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank)
152
+ torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank)
153
+
154
+ if self.global_rank == target_rank:
155
+ # Perform a second top-k on the k * global_size retrieved neighbors
156
+ topk_sims_rank = torch.cat(topk_sims_rank, dim=1)
157
+ retrieved_rank = torch.cat(retrieved_rank, dim=1)
158
+ results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank)
159
+ return results
160
+ return None
161
+
162
+ def compute_neighbors(self, features_rank):
163
+ for rank in range(self.global_size):
164
+ topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank)
165
+ results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank)
166
+ if results is not None:
167
+ topk_sims_rank, neighbors_labels_rank = results
168
+ return topk_sims_rank, neighbors_labels_rank
169
+
170
+ def forward(self, features_rank):
171
+ """
172
+ Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
173
+ """
174
+ assert all(k <= self.max_k for k in self.nb_knn)
175
+
176
+ topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
177
+ batch_size = neighbors_labels.shape[0]
178
+ topk_sims_transform = softmax(topk_sims / self.T, 1)
179
+ matmul = torch.mul(
180
+ one_hot(neighbors_labels, num_classes=self.num_classes),
181
+ topk_sims_transform.view(batch_size, -1, 1),
182
+ )
183
+ probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn}
184
+ return probas_for_k
185
+
186
+
187
+ class DictKeysModule(torch.nn.Module):
188
+ def __init__(self, keys):
189
+ super().__init__()
190
+ self.keys = keys
191
+
192
+ def forward(self, features_dict, targets):
193
+ for k in self.keys:
194
+ features_dict = features_dict[k]
195
+ return {"preds": features_dict, "target": targets}
196
+
197
+
198
+ def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
199
+ modules = {}
200
+ mapping = create_class_indices_mapping(train_labels)
201
+ for npc in n_per_class_list:
202
+ if npc < 0: # Only one try needed when using the full data
203
+ full_module = module(
204
+ train_features=train_features,
205
+ train_labels=train_labels,
206
+ nb_knn=nb_knn,
207
+ )
208
+ modules["full"] = ModuleDictWithForward({"1": full_module})
209
+ continue
210
+ all_tries = {}
211
+ for t in range(n_tries):
212
+ final_indices = filter_train(mapping, npc, seed=t)
213
+ k_list = list(set(nb_knn + [npc]))
214
+ k_list = sorted([el for el in k_list if el <= npc])
215
+ all_tries[str(t)] = module(
216
+ train_features=train_features[final_indices],
217
+ train_labels=train_labels[final_indices],
218
+ nb_knn=k_list,
219
+ )
220
+ modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)
221
+
222
+ return ModuleDictWithForward(modules)
223
+
224
+
225
+ def filter_train(mapping, n_per_class, seed):
226
+ torch.manual_seed(seed)
227
+ final_indices = []
228
+ for k in mapping.keys():
229
+ index = torch.randperm(len(mapping[k]))[:n_per_class]
230
+ final_indices.append(mapping[k][index])
231
+ return torch.cat(final_indices).squeeze()
232
+
233
+
234
+ def create_class_indices_mapping(labels):
235
+ unique_labels, inverse = torch.unique(labels, return_inverse=True)
236
+ mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))}
237
+ return mapping
238
+
239
+
240
+ class ModuleDictWithForward(torch.nn.ModuleDict):
241
+ def forward(self, *args, **kwargs):
242
+ return {k: module(*args, **kwargs) for k, module in self._modules.items()}
243
+
244
+
245
+ def eval_knn(
246
+ model,
247
+ train_dataset,
248
+ val_dataset,
249
+ accuracy_averaging,
250
+ nb_knn,
251
+ temperature,
252
+ batch_size,
253
+ num_workers,
254
+ gather_on_cpu,
255
+ n_per_class_list=[-1],
256
+ n_tries=1,
257
+ ):
258
+ model = ModelWithNormalize(model)
259
+
260
+ logger.info("Extracting features for train set...")
261
+ train_features, train_labels = extract_features(
262
+ model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
263
+ )
264
+ logger.info(f"Train features created, shape {train_features.shape}.")
265
+
266
+ val_dataloader = make_data_loader(
267
+ dataset=val_dataset,
268
+ batch_size=batch_size,
269
+ num_workers=num_workers,
270
+ sampler_type=SamplerType.DISTRIBUTED,
271
+ drop_last=False,
272
+ shuffle=False,
273
+ persistent_workers=True,
274
+ )
275
+ num_classes = train_labels.max() + 1
276
+ metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)
277
+
278
+ device = torch.cuda.current_device()
279
+ partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
280
+ knn_module_dict = create_module_dict(
281
+ module=partial_module,
282
+ n_per_class_list=n_per_class_list,
283
+ n_tries=n_tries,
284
+ nb_knn=nb_knn,
285
+ train_features=train_features,
286
+ train_labels=train_labels,
287
+ )
288
+ postprocessors, metrics = {}, {}
289
+ for n_per_class, knn_module in knn_module_dict.items():
290
+ for t, knn_try in knn_module.items():
291
+ postprocessors = {
292
+ **postprocessors,
293
+ **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn},
294
+ }
295
+ metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}}
296
+ model_with_knn = torch.nn.Sequential(model, knn_module_dict)
297
+
298
+ # ============ evaluation ... ============
299
+ logger.info("Start the k-NN classification.")
300
+ _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)
301
+
302
+ # Averaging the results over the n tries for each value of n_per_class
303
+ for n_per_class, knn_module in knn_module_dict.items():
304
+ first_try = list(knn_module.keys())[0]
305
+ k_list = knn_module[first_try].nb_knn
306
+ for k in k_list:
307
+ keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5`
308
+ results_dict[(n_per_class, k)] = {
309
+ key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()]))
310
+ for key in keys
311
+ }
312
+ for t in knn_module.keys():
313
+ del results_dict[(n_per_class, t, k)]
314
+
315
+ return results_dict
316
+
317
+
318
+ def eval_knn_with_model(
319
+ model,
320
+ output_dir,
321
+ train_dataset_str="ImageNet:split=TRAIN",
322
+ val_dataset_str="ImageNet:split=VAL",
323
+ nb_knn=(10, 20, 100, 200),
324
+ temperature=0.07,
325
+ autocast_dtype=torch.float,
326
+ accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
327
+ transform=None,
328
+ gather_on_cpu=False,
329
+ batch_size=256,
330
+ num_workers=5,
331
+ n_per_class_list=[-1],
332
+ n_tries=1,
333
+ ):
334
+ transform = transform or make_classification_eval_transform()
335
+
336
+ train_dataset = make_dataset(
337
+ dataset_str=train_dataset_str,
338
+ transform=transform,
339
+ )
340
+ val_dataset = make_dataset(
341
+ dataset_str=val_dataset_str,
342
+ transform=transform,
343
+ )
344
+
345
+ with torch.cuda.amp.autocast(dtype=autocast_dtype):
346
+ results_dict_knn = eval_knn(
347
+ model=model,
348
+ train_dataset=train_dataset,
349
+ val_dataset=val_dataset,
350
+ accuracy_averaging=accuracy_averaging,
351
+ nb_knn=nb_knn,
352
+ temperature=temperature,
353
+ batch_size=batch_size,
354
+ num_workers=num_workers,
355
+ gather_on_cpu=gather_on_cpu,
356
+ n_per_class_list=n_per_class_list,
357
+ n_tries=n_tries,
358
+ )
359
+
360
+ results_dict = {}
361
+ if distributed.is_main_process():
362
+ for knn_ in results_dict_knn.keys():
363
+ top1 = results_dict_knn[knn_]["top-1"].item() * 100.0
364
+ top5 = results_dict_knn[knn_]["top-5"].item() * 100.0
365
+ results_dict[f"{knn_} Top 1"] = top1
366
+ results_dict[f"{knn_} Top 5"] = top5
367
+ logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")
368
+
369
+ metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
370
+ with open(metrics_file_path, "a") as f:
371
+ for k, v in results_dict.items():
372
+ f.write(json.dumps({k: v}) + "\n")
373
+
374
+ if distributed.is_enabled():
375
+ torch.distributed.barrier()
376
+ return results_dict
377
+
378
+
379
+ def main(args):
380
+ model, autocast_dtype = setup_and_build_model(args)
381
+ eval_knn_with_model(
382
+ model=model,
383
+ output_dir=args.output_dir,
384
+ train_dataset_str=args.train_dataset_str,
385
+ val_dataset_str=args.val_dataset_str,
386
+ nb_knn=args.nb_knn,
387
+ temperature=args.temperature,
388
+ autocast_dtype=autocast_dtype,
389
+ accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
390
+ transform=None,
391
+ gather_on_cpu=args.gather_on_cpu,
392
+ batch_size=args.batch_size,
393
+ num_workers=5,
394
+ n_per_class_list=args.n_per_class_list,
395
+ n_tries=args.n_tries,
396
+ )
397
+ return 0
398
+
399
+
400
+ if __name__ == "__main__":
401
+ description = "DINOv2 k-NN evaluation"
402
+ args_parser = get_args_parser(description=description)
403
+ args = args_parser.parse_args()
404
+ sys.exit(main(args))